From 8ac8c13e0067a55673126470b7ecdf1f25434bf9 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 2 Dec 2024 19:22:25 -0800 Subject: [PATCH 01/30] Remove unused scope --- oauth2/oauth2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 6f232275..1e5c9df3 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -32,7 +32,7 @@ func DiscoverConfig(ctx context.Context, domain, clientID string) (*oauth2.Confi cfg := oauth2.Config{ ClientID: clientID, Endpoint: provider.Endpoint(), - Scopes: []string{"openid", "profile", "okta.apps.read", "okta.apps.sso"}, + Scopes: []string{"openid", "profile", "okta.apps.sso"}, } return &cfg, nil From 34e2b8bfafc3f456107ffb8847db6fd06737d25e Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 2 Dec 2024 19:43:10 -0800 Subject: [PATCH 02/30] Use oauth2 package to generate verifier --- command/login.go | 2 +- go.mod | 2 +- go.sum | 2 ++ oauth2/oauth2.go | 26 ++++---------------------- 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/command/login.go b/command/login.go index b6c69233..ae2135fe 100644 --- a/command/login.go +++ b/command/login.go @@ -100,7 +100,7 @@ func (c LoginCommand) Execute(ctx context.Context, config *Config) error { } } - accessToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GeneratePkceChallenge(), oauth2.GenerateState()) + accessToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GenerateState()) if err != nil { return err } diff --git a/go.mod b/go.mod index a3748bb2..9903ff09 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 golang.org/x/net v0.25.0 - golang.org/x/oauth2 v0.6.0 + golang.org/x/oauth2 v0.24.0 ) require ( diff --git a/go.sum b/go.sum index 90f06203..441bc250 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,8 @@ golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 1e5c9df3..47e32027 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -3,7 +3,6 @@ package oauth2 import ( "context" "crypto/rand" - "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -116,41 +115,24 @@ func (e OAuth2Error) Error() string { return fmt.Sprintf("oauth2 error: %s (%s)", e.Description, e.Reason) } -func GeneratePkceChallenge() PkceChallenge { - codeVerifierBuf := make([]byte, stateBufSize) - rand.Read(codeVerifierBuf) - codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBuf) - codeChallengeHash := sha256.Sum256([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(codeChallengeHash[:]) - return PkceChallenge{Verifier: codeVerifier, Challenge: codeChallenge} -} - func GenerateState() string { stateBuf := make([]byte, stateBufSize) rand.Read(stateBuf) return base64.URLEncoding.EncodeToString(stateBuf) } -type PkceChallenge struct { - Challenge string - Verifier string -} - type RedirectionFlowHandler struct { Config *oauth2.Config OnDisplayURL func(url string) error } -func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, challenge PkceChallenge, state string) (*oauth2.Token, error) { +func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, state string) (*oauth2.Token, error) { if r.OnDisplayURL == nil { panic("OnDisplayURL must be set") } - url := r.Config.AuthCodeURL(state, - oauth2.SetAuthURLParam("code_challenge_method", "S256"), - oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), - ) - + verifier := oauth2.GenerateVerifier() + url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) callbackHandler, ch, cancel := OAuth2CallbackHandler() // TODO: This error probably should not be ignored if it is not http.ErrServerClosed go http.Serve(listener, callbackHandler) @@ -167,7 +149,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listen if err != nil { return nil, fmt.Errorf("failed to get authorization code: %w", err) } - return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) + return r.Config.Exchange(ctx, code, oauth2.VerifierOption(verifier)) case <-ctx.Done(): return nil, ctx.Err() } From 1dab84a641d822afd9e7e95bc6bee015ac1e9d77 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 2 Dec 2024 19:54:20 -0800 Subject: [PATCH 03/30] Exchange the code for a token inside of the callback --- command/login.go | 8 +---- oauth2/oauth2.go | 82 +++++++++++++++++++++++++------------------ oauth2/oauth2_test.go | 62 +++++--------------------------- 3 files changed, 57 insertions(+), 95 deletions(-) diff --git a/command/login.go b/command/login.go index ae2135fe..7a0cb292 100644 --- a/command/login.go +++ b/command/login.go @@ -100,17 +100,11 @@ func (c LoginCommand) Execute(ctx context.Context, config *Config) error { } } - accessToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GenerateState()) + accessToken, idToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GenerateState()) if err != nil { return err } - // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - idToken, ok := accessToken.Extra("id_token").(string) - if !ok { - return fmt.Errorf("id_token not found in token response") - } - return config.SaveOAuthToken(accessToken, idToken) } diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 47e32027..57a55330 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "strings" - "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -70,40 +69,53 @@ func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { return o.code, nil } +type Callback struct { + Token *oauth2.Token + IDToken *string + Error error +} + +type CodeExchanger interface { + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) +} + // OAuth2CallbackHandler returns a http.Handler, channel and function triple. // // The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. // // The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. -func OAuth2CallbackHandler() (http.Handler, <-chan OAuth2CallbackState, func()) { - // TODO: It is possible for the caller to close a panic() if they execute the function in the triplet while the handler has not yet received a request. - // That caller is us, so I don't care that much, but that probably indicates that this design is smelly. - // - // We should look at the Go SDK to see how they handle similar cases - channels that are not bound by a timer, or similar. - - ch := make(chan OAuth2CallbackState, 1) - var reqHandle, closeHandle sync.Once - closeFn := func() { - closeHandle.Do(func() { - close(ch) - }) - } - +func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { // This can sometimes be called multiple times, depending on the browser. // We will simply ignore any other requests and only serve the first. - reqHandle.Do(func() { - var state OAuth2CallbackState - state.FromRequest(r) - ch <- state - closeFn() - }) - - // We still want to provide feedback to the end-user. + var info OAuth2CallbackState + info.FromRequest(r) + + code, err := info.Verify(state) + if err != nil { + ch <- Callback{Error: err} + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + token, err := codeEx.Exchange(r.Context(), code, oauth2.VerifierOption(verifier)) + if err != nil { + ch <- Callback{Error: err} + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse + if idToken, ok := token.Extra("id_token").(string); ok { + ch <- Callback{Token: token, IDToken: &idToken} + } else { + ch <- Callback{Token: token} + } + fmt.Fprintln(w, "You may close this window now.") } - return http.HandlerFunc(fn), ch, closeFn + return http.HandlerFunc(fn) } type OAuth2Error struct { @@ -126,32 +138,32 @@ type RedirectionFlowHandler struct { OnDisplayURL func(url string) error } -func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, state string) (*oauth2.Token, error) { +func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, state string) (*oauth2.Token, string, error) { if r.OnDisplayURL == nil { panic("OnDisplayURL must be set") } verifier := oauth2.GenerateVerifier() url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) - callbackHandler, ch, cancel := OAuth2CallbackHandler() + + ch := make(chan Callback, 1) // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(listener, callbackHandler) - defer cancel() + go http.Serve(listener, OAuth2CallbackHandler(r.Config, state, verifier, ch)) if err := r.OnDisplayURL(url); err != nil { - // This is unlikely to ever happen - return nil, fmt.Errorf("failed to display link: %w", err) + return nil, "", fmt.Errorf("failed to display link: %w", err) } select { case info := <-ch: - code, err := info.Verify(state) - if err != nil { - return nil, fmt.Errorf("failed to get authorization code: %w", err) + // TODO: Close the server immediately to prevent any more requests being received. + if info.Error != nil { + return nil, "", info.Error } - return r.Config.Exchange(ctx, code, oauth2.VerifierOption(verifier)) + + return info.Token, "", nil case <-ctx.Done(): - return nil, ctx.Err() + return nil, "", ctx.Err() } } diff --git a/oauth2/oauth2_test.go b/oauth2/oauth2_test.go index f5d2ef2d..400007cc 100644 --- a/oauth2/oauth2_test.go +++ b/oauth2/oauth2_test.go @@ -1,12 +1,13 @@ package oauth2 import ( + "context" "net/http" "net/http/httptest" "net/url" "testing" - "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" ) func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { @@ -22,59 +23,18 @@ func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { handler.ServeHTTP(w, req) } -func Test_OAuth2CallbackHandler_YieldsCorrectlyFormattedState(t *testing.T) { - handler, ch, cancel := OAuth2CallbackHandler() - t.Cleanup(func() { - cancel() - }) - - expectedState := "state goes here" - expectedCode := "code goes here" +type testCodeExchanger struct{} - go sendOAuth2CallbackRequest(handler, url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, - }) - - callbackState := <-ch - code, err := callbackState.Verify(expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) -} - -func Test_OAuth2CallbackState_VerifyWorksCorrectly(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - expectedState := "state goes here" - expectedCode := "code goes here" - callbackState := OAuth2CallbackState{ - code: expectedCode, - state: expectedState, - } - code, err := callbackState.Verify(expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) - }) - - t.Run("unhappy path", func(t *testing.T) { - expectedState := "state goes here" - expectedCode := "code goes here" - callbackState := OAuth2CallbackState{ - code: expectedCode, - state: expectedState, - } - _, err := callbackState.Verify("mismatching state") - var oauthErr OAuth2Error - assert.ErrorAs(t, err, &oauthErr) - assert.Equal(t, "invalid_state", oauthErr.Reason) - }) +func (t *testCodeExchanger) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return nil, nil } // Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { - handler, ch, cancel := OAuth2CallbackHandler() - t.Cleanup(func() { - cancel() - }) + ch := make(chan Callback, 2) + defer close(ch) + + handler := OAuth2CallbackHandler(&testCodeExchanger{}, "state", "verifier", ch) go sendOAuth2CallbackRequest(handler, url.Values{ // We send empty values because we don't care about processing in this test @@ -82,10 +42,6 @@ func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { "state": []string{""}, }) - // We drain the channel of the first request so the handler completes. - // Without this step, we would get 'stuck' in the sync.Once(). - <-ch - // We send this request synchronously to ensure that any panics are caught during the test. sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{"not the expected code and should be discarded"}, From 56dff97a538bae614b63ddfe1b6d6ee7bd5615e1 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 08:26:57 -0800 Subject: [PATCH 04/30] Get Kong --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 9903ff09..d8cb376a 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( ) require ( + github.com/alecthomas/kong v1.4.0 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.44 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect diff --git a/go.sum b/go.sum index 441bc250..20c27d1c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b h1:EgJ6N2S0h1WfFIjU5/VVHWbMSVYXAluop97Qxpr/lfQ= github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b/go.mod h1:3SAoF0F5EbcOuBD5WT9nYkbIJieBS84cUQXADbXeBsU= +github.com/alecthomas/kong v1.4.0 h1:UL7tzGMnnY0YRMMvJyITIRX1EpO6RbBRZDNcCevy3HA= +github.com/alecthomas/kong v1.4.0/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI= github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= From f0bfb01a002733d39c194529d905754ef983a041 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 08:42:16 -0800 Subject: [PATCH 05/30] Remove Config and Args from the struct --- command/get.go | 36 ++++++++++++++++-------------------- command/login.go | 16 +++++++++------- command/root.go | 21 +++++++++++++++++++-- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/command/get.go b/command/get.go index 751118cb..ab0e3837 100644 --- a/command/get.go +++ b/command/get.go @@ -55,7 +55,7 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - AccountIDOrName string + AccountNameOrID string `arg:""` TimeToLive uint TimeRemaining uint OutputType, ShellType, RoleName, AWSCLIPath, OIDCDomain, ClientID, Region string @@ -86,7 +86,7 @@ func (g *GetCommand) Parse(cmd *cobra.Command, args []string) error { if len(args) == 0 { return fmt.Errorf("account name or alias is required") } - g.AccountIDOrName = args[0] + g.AccountNameOrID = args[0] return nil } @@ -105,8 +105,8 @@ func (g GetCommand) printUsage() error { return g.UsageFunc() } -func (g GetCommand) Execute(ctx context.Context, config *Config) error { - if HasTokenExpired(config.Tokens) { +func (g GetCommand) Execute(ctx context.Context, cfg *Config) error { + if HasTokenExpired(cfg.Tokens) { if !g.Login { return ErrTokensExpiredOrAbsent } @@ -118,24 +118,24 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error { NoBrowser: g.NoBrowser, } - if err := loginCommand.Execute(ctx, config); err != nil { + if err := loginCommand.Execute(ctx, cfg); err != nil { return err } } var accountID string - if g.AccountIDOrName != "" { - accountID = g.AccountIDOrName - } else if config.LastUsedAccount != nil { + if g.AccountNameOrID != "" { + accountID = g.AccountNameOrID + } else if cfg.LastUsedAccount != nil { // No account specified. Can we use the most recent one? - accountID = *config.LastUsedAccount + accountID = *cfg.LastUsedAccount } else { return g.printUsage() } - account, ok := resolveApplicationInfo(config, g.BypassCache, accountID) + account, ok := resolveApplicationInfo(cfg, g.BypassCache, accountID) if !ok { - return UnknownAccountError(g.AccountIDOrName, FlagBypassCache) + return UnknownAccountError(g.AccountNameOrID, FlagBypassCache) } if g.RoleName == "" { @@ -146,13 +146,13 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error { g.RoleName = account.MostRecentRole } - if config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { - g.TimeRemaining = config.TimeRemaining + if cfg.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { + g.TimeRemaining = cfg.TimeRemaining } credentials := LoadAWSCredentialsFromEnvironment() if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) { - newCredentials, err := g.fetchNewCredentials(ctx, *account, config) + newCredentials, err := g.fetchNewCredentials(ctx, *account, cfg) if err != nil { return err } @@ -163,7 +163,7 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error { account.MostRecentRole = g.RoleName } - config.LastUsedAccount = &accountID + cfg.LastUsedAccount = &accountID return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath) } @@ -175,7 +175,7 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cf pair, ok := findRoleInSAML(g.RoleName, samlResponse) if !ok { - return nil, UnknownRoleError(g.RoleName, g.AccountIDOrName) + return nil, UnknownRoleError(g.RoleName, g.AccountNameOrID) } if g.TimeToLive == 1 && cfg.TTL != 0 { @@ -227,10 +227,6 @@ A role must be specified when using this command through the --role flag. You ma return err } - if err := getCmd.Validate(); err != nil { - return err - } - return getCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) }, } diff --git a/command/login.go b/command/login.go index 7a0cb292..39fda0a8 100644 --- a/command/login.go +++ b/command/login.go @@ -26,9 +26,7 @@ func init() { } var loginCmd = &cobra.Command{ - Use: "login", - Short: "Authenticate with KeyConjurer.", - Long: "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code.", + Use: "login", RunE: func(cmd *cobra.Command, args []string) error { var loginCmd LoginCommand if err := loginCmd.Parse(cmd.Flags(), args); err != nil { @@ -50,10 +48,14 @@ func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { } type LoginCommand struct { - OIDCDomain string - ClientID string - MachineOutput bool - NoBrowser bool + OIDCDomain string `help:"The domain name of your OIDC server" hidden:""` + ClientID string `help:"The client ID of your OIDC server" hidden:""` + MachineOutput bool `kong:"-"` + NoBrowser bool `kong:"-"` +} + +func (c LoginCommand) Help() string { + return "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code." } func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { diff --git a/command/root.go b/command/root.go index b39fe717..1d18a83b 100644 --- a/command/root.go +++ b/command/root.go @@ -9,6 +9,7 @@ import ( "runtime" "time" + "github.com/alecthomas/kong" "github.com/coreos/go-oidc" "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" @@ -107,9 +108,25 @@ To get started run the following commands: SilenceUsage: true, } +var CLI struct { + Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` + Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` + // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` +} + func Execute(ctx context.Context, args []string) error { client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} ctx = oidc.ClientContext(ctx, client) - rootCmd.SetArgs(args) - return rootCmd.ExecuteContext(ctx) + + k, err := kong.New(&CLI) + if err != nil { + return err + } + + kongCtx, err := k.Parse(args) + if err != nil { + return err + } + + return kongCtx.Run(ctx) } From adc4deb5e476996ee6951d257343fd9dce04dbc9 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 08:46:39 -0800 Subject: [PATCH 06/30] WIP: convert to Kong --- command/get.go | 2 +- command/login.go | 20 +------------------- command/root.go | 3 +-- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/command/get.go b/command/get.go index ab0e3837..361210a8 100644 --- a/command/get.go +++ b/command/get.go @@ -118,7 +118,7 @@ func (g GetCommand) Execute(ctx context.Context, cfg *Config) error { NoBrowser: g.NoBrowser, } - if err := loginCommand.Execute(ctx, cfg); err != nil { + if err := loginCommand.Run(ctx, cfg); err != nil { return err } } diff --git a/command/login.go b/command/login.go index 39fda0a8..2a0205e0 100644 --- a/command/login.go +++ b/command/login.go @@ -11,7 +11,6 @@ import ( "github.com/pkg/browser" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" "github.com/spf13/pflag" ) @@ -20,23 +19,6 @@ var ( FlagNoBrowser = "no-browser" ) -func init() { - loginCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message") - loginCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") -} - -var loginCmd = &cobra.Command{ - Use: "login", - RunE: func(cmd *cobra.Command, args []string) error { - var loginCmd LoginCommand - if err := loginCmd.Parse(cmd.Flags(), args); err != nil { - return err - } - - return loginCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) - }, -} - // ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine. // // What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful. @@ -67,7 +49,7 @@ func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { return nil } -func (c LoginCommand) Execute(ctx context.Context, config *Config) error { +func (c LoginCommand) Run(ctx context.Context, config *Config) error { if !HasTokenExpired(config.Tokens) { return nil } diff --git a/command/root.go b/command/root.go index 1d18a83b..06590f2b 100644 --- a/command/root.go +++ b/command/root.go @@ -29,7 +29,6 @@ func init() { rootCmd.PersistentFlags().Int(FlagTimeout, 120, "the amount of time in seconds to wait for keyconjurer to respond") rootCmd.PersistentFlags().String(FlagConfigPath, "~/.keyconjurerrc", "path to .keyconjurerrc file") rootCmd.PersistentFlags().Bool(FlagQuiet, false, "tells the CLI to be quiet; stdout will not contain human-readable informational messages") - rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(accountsCmd) rootCmd.AddCommand(getCmd) rootCmd.AddCommand(setCmd) @@ -128,5 +127,5 @@ func Execute(ctx context.Context, args []string) error { return err } - return kongCtx.Run(ctx) + return kongCtx.Run(kong.Bind(ctx)) } From 4fd5981787b6554d18dff1401a35893137b2ce3b Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 09:07:26 -0800 Subject: [PATCH 07/30] Convert CLI to a struct so we can add hooks to it --- command/root.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/command/root.go b/command/root.go index 06590f2b..de894cae 100644 --- a/command/root.go +++ b/command/root.go @@ -107,7 +107,7 @@ To get started run the following commands: SilenceUsage: true, } -var CLI struct { +type CLI struct { Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` @@ -117,7 +117,8 @@ func Execute(ctx context.Context, args []string) error { client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} ctx = oidc.ClientContext(ctx, client) - k, err := kong.New(&CLI) + var cli CLI + k, err := kong.New(&cli) if err != nil { return err } From cec1a308f8e8ae31cfb2efab51b07a3618065d9e Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 09:09:00 -0800 Subject: [PATCH 08/30] Ignore function args --- command/get.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/command/get.go b/command/get.go index 361210a8..2c76ba93 100644 --- a/command/get.go +++ b/command/get.go @@ -61,8 +61,8 @@ type GetCommand struct { OutputType, ShellType, RoleName, AWSCLIPath, OIDCDomain, ClientID, Region string Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool - UsageFunc func() error - PrintErrln func(...any) + UsageFunc func() error `kong:"-"` + PrintErrln func(...any) `kong:"-"` } func (g *GetCommand) Parse(cmd *cobra.Command, args []string) error { From ef83ff440ed7775bd20e59401eb3e9d78576a59b Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 09:20:46 -0800 Subject: [PATCH 09/30] Add RunContext() because Kong doesn't support binding context well to Run() --- command/get.go | 10 +++++++--- command/login.go | 11 +++++++++-- command/root.go | 11 +++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/command/get.go b/command/get.go index 2c76ba93..5be53080 100644 --- a/command/get.go +++ b/command/get.go @@ -105,7 +105,7 @@ func (g GetCommand) printUsage() error { return g.UsageFunc() } -func (g GetCommand) Execute(ctx context.Context, cfg *Config) error { +func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { if HasTokenExpired(cfg.Tokens) { if !g.Login { return ErrTokensExpiredOrAbsent @@ -118,7 +118,7 @@ func (g GetCommand) Execute(ctx context.Context, cfg *Config) error { NoBrowser: g.NoBrowser, } - if err := loginCommand.Run(ctx, cfg); err != nil { + if err := loginCommand.RunContext(ctx, cfg); err != nil { return err } } @@ -167,6 +167,10 @@ func (g GetCommand) Execute(ctx context.Context, cfg *Config) error { return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath) } +func (g GetCommand) Run(cfg *Config) error { + return g.RunContext(context.Background(), cfg) +} + func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) { samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) if err != nil { @@ -227,7 +231,7 @@ A role must be specified when using this command through the --role flag. You ma return err } - return getCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) + return getCmd.RunContext(cmd.Context(), ConfigFromCommand(cmd)) }, } diff --git a/command/login.go b/command/login.go index 2a0205e0..a107877d 100644 --- a/command/login.go +++ b/command/login.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net" + "net/http" "os" "log/slog" + "github.com/coreos/go-oidc" "github.com/pkg/browser" "github.com/riotgames/key-conjurer/oauth2" "github.com/spf13/pflag" @@ -49,12 +51,13 @@ func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { return nil } -func (c LoginCommand) Run(ctx context.Context, config *Config) error { +func (c LoginCommand) RunContext(ctx context.Context, config *Config) error { if !HasTokenExpired(config.Tokens) { return nil } - oauthCfg, err := oauth2.DiscoverConfig(ctx, c.OIDCDomain, c.ClientID) + client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} + oauthCfg, err := oauth2.DiscoverConfig(oidc.ClientContext(ctx, client), c.OIDCDomain, c.ClientID) if err != nil { return err } @@ -92,6 +95,10 @@ func (c LoginCommand) Run(ctx context.Context, config *Config) error { return config.SaveOAuthToken(accessToken, idToken) } +func (c LoginCommand) Run(config *Config) error { + return c.RunContext(context.Background(), config) +} + var ErrNoPortsAvailable = errors.New("no ports available") // findFirstFreePort will attempt to open a network listener for each port in turn, and return the first one that succeeded. diff --git a/command/root.go b/command/root.go index de894cae..aa14422a 100644 --- a/command/root.go +++ b/command/root.go @@ -3,14 +3,12 @@ package command import ( "context" "fmt" - "net/http" "os" "path/filepath" "runtime" "time" "github.com/alecthomas/kong" - "github.com/coreos/go-oidc" "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" ) @@ -114,9 +112,6 @@ type CLI struct { } func Execute(ctx context.Context, args []string) error { - client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} - ctx = oidc.ClientContext(ctx, client) - var cli CLI k, err := kong.New(&cli) if err != nil { @@ -128,5 +123,9 @@ func Execute(ctx context.Context, args []string) error { return err } - return kongCtx.Run(kong.Bind(ctx)) + var config Config + return kongCtx.Run( + // This should be moved to a hook in Root so that we can use flags to load the config. + &config, + ) } From fb8504fe72ac0fd5ca8842f53a684390dfd0a583 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 11:09:24 -0800 Subject: [PATCH 10/30] Add Config parsing --- command/root.go | 123 ++++++++++++++++++++++-------------------------- 1 file changed, 55 insertions(+), 68 deletions(-) diff --git a/command/root.go b/command/root.go index aa14422a..b1a6bd13 100644 --- a/command/root.go +++ b/command/root.go @@ -6,7 +6,6 @@ import ( "os" "path/filepath" "runtime" - "time" "github.com/alecthomas/kong" "github.com/mitchellh/go-homedir" @@ -18,15 +17,9 @@ var ( FlagClientID = "client-id" FlagConfigPath = "config" FlagQuiet = "quiet" - FlagTimeout = "timeout" ) func init() { - rootCmd.PersistentFlags().String(FlagOIDCDomain, OIDCDomain, "The domain name of your OIDC server") - rootCmd.PersistentFlags().String(FlagClientID, ClientID, "The OAuth2 Client ID for the application registered with your OIDC server") - rootCmd.PersistentFlags().Int(FlagTimeout, 120, "the amount of time in seconds to wait for keyconjurer to respond") - rootCmd.PersistentFlags().String(FlagConfigPath, "~/.keyconjurerrc", "path to .keyconjurerrc file") - rootCmd.PersistentFlags().Bool(FlagQuiet, false, "tells the CLI to be quiet; stdout will not contain human-readable informational messages") rootCmd.AddCommand(accountsCmd) rootCmd.AddCommand(getCmd) rootCmd.AddCommand(setCmd) @@ -35,9 +28,6 @@ func init() { rootCmd.AddCommand(&unaliasCmd) rootCmd.AddCommand(&rolesCmd) rootCmd.SetVersionTemplate("{{.Version}}\n") - - rootCmd.PersistentFlags().MarkHidden(FlagOIDCDomain) - rootCmd.PersistentFlags().MarkHidden(FlagClientID) } // rootCmd represents the base command when called without any subcommands @@ -45,62 +35,9 @@ var rootCmd = &cobra.Command{ Use: "keyconjurer", Version: fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), Short: "Retrieve temporary cloud credentials.", - Long: `KeyConjurer retrieves temporary credentials from Okta with the assistance of an optional API. - -To get started run the following commands: - keyconjurer login - keyconjurer accounts - keyconjurer get -`, FParseErrWhitelist: cobra.FParseErrWhitelist{ UnknownFlags: true, }, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var config Config - // The error of this function call is only non-nil if the flag was not provided or is not a string. - configPath, _ := cmd.Flags().GetString(FlagConfigPath) - if expanded, err := homedir.Expand(configPath); err == nil { - configPath = expanded - } - - file, err := EnsureConfigFileExists(configPath) - if err != nil { - return err - } - - if err := config.Read(file); err != nil { - return err - } - - // We don't care about this being cancelled. - timeout, _ := cmd.Flags().GetInt(FlagTimeout) - nextCtx, _ := context.WithTimeout(cmd.Context(), time.Duration(timeout)*time.Second) - cmd.SetContext(ConfigContext(nextCtx, &config, configPath)) - return nil - }, - PersistentPostRunE: func(cmd *cobra.Command, _ []string) error { - config := ConfigFromCommand(cmd) - path := ConfigPathFromCommand(cmd) - if expanded, err := homedir.Expand(path); err == nil { - path = expanded - } - - // Do not use EnsureConfigFileExists here! - // EnsureConfigFileExists opens the file in append mode. - // If we open the file in append mode, we'll always append to the file. If we open the file in truncate mode before reading from the file, the content will be truncated _before we read from it_, which will cause a users configuration to be discarded every time we run the program. - - if err := os.MkdirAll(filepath.Dir(path), os.ModeDir|os.FileMode(0755)); err != nil { - return err - } - - file, err := os.Create(path) - if err != nil { - return fmt.Errorf("unable to create %s reason: %w", path, err) - } - - defer file.Close() - return config.Write(file) - }, SilenceErrors: true, SilenceUsage: true, } @@ -109,6 +46,60 @@ type CLI struct { Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + + Config Config `kong:"-"` + ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc"` + Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` +} + +func (CLI) Help() string { + return `KeyConjurer retrieves temporary credentials from Okta with the assistance of an optional API. + +To get started run the following commands: + keyconjurer login + keyconjurer accounts + keyconjurer get ` +} + +func (c *CLI) BeforeApply(ctx *kong.Context, trace *kong.Path) error { + if expanded, err := homedir.Expand(c.ConfigPath); err == nil { + c.ConfigPath = expanded + } + + file, err := EnsureConfigFileExists(c.ConfigPath) + if err != nil { + return err + } + + err = c.Config.Read(file) + if err != nil { + return err + } + + // Make *Config available to all sub-commands. + // This must be &c.Config because c.Config is not a pointer. + ctx.Bind(&c.Config) + return nil +} + +func (c *CLI) AfterRun(ctx *kong.Context) error { + if expanded, err := homedir.Expand(c.ConfigPath); err == nil { + c.ConfigPath = expanded + } + + // Do not use EnsureConfigFileExists here! EnsureConfigFileExists opens the file in append mode. + // If we open the file in append mode, we'll always append to the file. If we open the file in truncate mode before reading from the file, the content will be truncated _before we read from it_, which will cause a users configuration to be discarded every time we run the program. + if err := os.MkdirAll(filepath.Dir(c.ConfigPath), os.ModeDir|os.FileMode(0755)); err != nil { + return err + } + + file, err := os.Create(c.ConfigPath) + if err != nil { + return fmt.Errorf("unable to create %s reason: %w", c.ConfigPath, err) + } + + defer file.Close() + return c.Config.Write(file) } func Execute(ctx context.Context, args []string) error { @@ -123,9 +114,5 @@ func Execute(ctx context.Context, args []string) error { return err } - var config Config - return kongCtx.Run( - // This should be moved to a hook in Root so that we can use flags to load the config. - &config, - ) + return kongCtx.Run() } From ac68d5f2d36f9745ea0f0e60f514d891e04f5f32 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 11:18:55 -0800 Subject: [PATCH 11/30] Add constants for OIDC Domain, Server Address and Client ID --- command/accounts.go | 4 ++++ command/get.go | 13 ++++++++----- command/login.go | 4 ++-- command/root.go | 8 +++++++- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/command/accounts.go b/command/accounts.go index 7227dbb8..bddc8553 100644 --- a/command/accounts.go +++ b/command/accounts.go @@ -26,6 +26,10 @@ func init() { accountsCmd.Flags().String(FlagServerAddress, ServerAddress, "The address of the account server. This does not usually need to be changed or specified.") } +type AccountsCommand struct { + ServerAddress string `help:"The address of the account server. This does not usually need to be changed or specified." hidden:"" env:"KEYCONJURER_SERVER_ADDRESS" default:"${server_address}"` +} + var accountsCmd = &cobra.Command{ Use: "accounts", Short: "Prints and optionally refreshes the list of accounts you have access to.", diff --git a/command/get.go b/command/get.go index 5be53080..e0f52e8b 100644 --- a/command/get.go +++ b/command/get.go @@ -55,11 +55,14 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - AccountNameOrID string `arg:""` - TimeToLive uint - TimeRemaining uint - OutputType, ShellType, RoleName, AWSCLIPath, OIDCDomain, ClientID, Region string - Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool + OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` + ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` + + AccountNameOrID string `arg:""` + TimeToLive uint + TimeRemaining uint + OutputType, ShellType, RoleName, AWSCLIPath, Region string + Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool UsageFunc func() error `kong:"-"` PrintErrln func(...any) `kong:"-"` diff --git a/command/login.go b/command/login.go index a107877d..b252a563 100644 --- a/command/login.go +++ b/command/login.go @@ -32,8 +32,8 @@ func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { } type LoginCommand struct { - OIDCDomain string `help:"The domain name of your OIDC server" hidden:""` - ClientID string `help:"The client ID of your OIDC server" hidden:""` + OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` + ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` MachineOutput bool `kong:"-"` NoBrowser bool `kong:"-"` } diff --git a/command/root.go b/command/root.go index b1a6bd13..a251c3f0 100644 --- a/command/root.go +++ b/command/root.go @@ -103,8 +103,14 @@ func (c *CLI) AfterRun(ctx *kong.Context) error { } func Execute(ctx context.Context, args []string) error { + vars := kong.Vars{ + "client_id": ClientID, + "server_address": ServerAddress, + "oidc_domain": OIDCDomain, + } + var cli CLI - k, err := kong.New(&cli) + k, err := kong.New(&cli, vars) if err != nil { return err } From fd92052e68f8c116b5a0ac7e059b3f8e6604d41a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 11:40:05 -0800 Subject: [PATCH 12/30] Annotate the get command --- command/get.go | 93 ++++++++++++++----------------------------------- command/root.go | 6 ++-- 2 files changed, 29 insertions(+), 70 deletions(-) diff --git a/command/get.go b/command/get.go index e0f52e8b..69b309b3 100644 --- a/command/get.go +++ b/command/get.go @@ -11,7 +11,6 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" ) var ( @@ -32,21 +31,6 @@ var ( permittedShellTypes = []string{shellTypePowershell, shellTypeBash, shellTypeBasic, shellTypeInfer} ) -func init() { - getCmd.Flags().String(FlagRegion, "us-west-2", "The AWS region to use") - getCmd.Flags().Uint(FlagTimeToLive, 1, "The key timeout in hours from 1 to 8.") - getCmd.Flags().UintP(FlagTimeRemaining, "t", DefaultTimeRemaining, "Request new keys if there are no keys in the environment or the current keys expire within minutes. Defaults to 60.") - getCmd.Flags().StringP(FlagRoleName, "r", "", "The name of the role to assume.") - getCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs") - getCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli") - getCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`") - getCmd.Flags().Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.") - getCmd.Flags().Bool(FlagLogin, false, "Login to Okta before running the command") - getCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws CLI") - getCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message") - getCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") -} - func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) { if bypassCache { return &Account{ID: nameOrID, Name: nameOrID}, true @@ -55,42 +39,31 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` - ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` - - AccountNameOrID string `arg:""` - TimeToLive uint - TimeRemaining uint - OutputType, ShellType, RoleName, AWSCLIPath, Region string - Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool - - UsageFunc func() error `kong:"-"` - PrintErrln func(...any) `kong:"-"` + OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` + ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` + AccountNameOrID string `arg:""` + TimeToLive uint `placeholder:"hours" help:"The key timeout in hours from 1 to 8." default:"1" name:"ttl"` + TimeRemaining uint `placeholder:"minutes" help:"Request new keys if there are no keys in the environment or the current keys expire within minutes." default:"5" short:"t"` + AWSCLIPath string `help:"Path to the AWS CLI configuration directory." default:"~/.aws/" name:"awscli"` + ShellType string `name:"shell" help:"If output type is env, determines which format to output credentials in. WSL users may wish to overwrite this to \"bash\"." default:"infer" enum:"infer,basic,powershell,bash"` + URLOnly bool `help:"Print only the URL to visit rather than a user-friendly message." short:"u"` + Browser bool `help:"Open the browser to the Okta URL. If false, a URL will be printed to the command line instead." default:"true" negatable:"" short:"b"` + OutputType string `help:"Format to save new credentials in." default:"env" enum:"env,awscli" short:"o" default:"env" name:"out"` + Login bool `help:"Login to Okta before running the command if the tokens have expired."` + RoleName string `help:"The name of the role to assume." short:"r" name:"role"` + SessionName string `help:"The name of the role session name that will show up in CloudTrail logs." default:"KeyConjurer-AssumeRole"` + Region string `help:"The AWS region to use." env:"AWS_REGION" default:"us-west-2"` + BypassCache bool `help:"Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache." hidden:""` + + UsageFunc func() error `kong:"-"` + PrintErrln func(...any) `kong:"-"` + MachineOutput bool `kong:"-"` } -func (g *GetCommand) Parse(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - g.OIDCDomain, _ = flags.GetString(FlagOIDCDomain) - g.ClientID, _ = flags.GetString(FlagClientID) - g.TimeToLive, _ = flags.GetUint(FlagTimeToLive) - g.TimeRemaining, _ = flags.GetUint(FlagTimeRemaining) - g.OutputType, _ = flags.GetString(FlagOutputType) - g.ShellType, _ = flags.GetString(FlagShellType) - g.RoleName, _ = flags.GetString(FlagRoleName) - g.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath) - g.Login, _ = flags.GetBool(FlagLogin) - g.URLOnly, _ = flags.GetBool(FlagURLOnly) - g.NoBrowser, _ = flags.GetBool(FlagNoBrowser) - g.BypassCache, _ = flags.GetBool(FlagBypassCache) - g.Region, _ = flags.GetString(FlagRegion) - g.UsageFunc = cmd.Usage - g.PrintErrln = cmd.PrintErrln - g.MachineOutput = ShouldUseMachineOutput(flags) || g.URLOnly - if len(args) == 0 { - return fmt.Errorf("account name or alias is required") - } - g.AccountNameOrID = args[0] - return nil +func (g GetCommand) Help() string { + return `Retrieves temporary cloud API credentials for the specified account. + +A role must be specified when using this command through the --role flag. You may list the roles you can assume through the roles command, and the accounts through the accounts command.` } func (g GetCommand) Validate() error { @@ -109,6 +82,8 @@ func (g GetCommand) printUsage() error { } func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { + // g.MachineOutput = ShouldUseMachineOutput(flags) || g.URLOnly + if HasTokenExpired(cfg.Tokens) { if !g.Login { return ErrTokensExpiredOrAbsent @@ -118,7 +93,7 @@ func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { OIDCDomain: g.OIDCDomain, ClientID: g.ClientID, MachineOutput: g.MachineOutput, - NoBrowser: g.NoBrowser, + NoBrowser: !g.Browser, } if err := loginCommand.RunContext(ctx, cfg); err != nil { @@ -222,22 +197,6 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cf }, nil } -var getCmd = &cobra.Command{ - Use: "get ", - Short: "Retrieves temporary cloud API credentials.", - Long: `Retrieves temporary cloud API credentials for the specified account. It sends a push request to the first Duo device it finds associated with your account. - -A role must be specified when using this command through the --role flag. You may list the roles you can assume through the roles command.`, - RunE: func(cmd *cobra.Command, args []string) error { - var getCmd GetCommand - if err := getCmd.Parse(cmd, args); err != nil { - return err - } - - return getCmd.RunContext(cmd.Context(), ConfigFromCommand(cmd)) - }, -} - func echoCredentials(id, name string, credentials CloudCredentials, outputType, shellType, cliPath string) error { switch outputType { case outputTypeEnvironmentVariable: diff --git a/command/root.go b/command/root.go index a251c3f0..f755839c 100644 --- a/command/root.go +++ b/command/root.go @@ -21,7 +21,6 @@ var ( func init() { rootCmd.AddCommand(accountsCmd) - rootCmd.AddCommand(getCmd) rootCmd.AddCommand(setCmd) rootCmd.AddCommand(&switchCmd) rootCmd.AddCommand(&aliasCmd) @@ -47,9 +46,10 @@ type CLI struct { Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - Config Config `kong:"-"` - ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc"` + ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` + + Config Config `kong:"-"` } func (CLI) Help() string { From 781acae2cb6fce3cb8c62aba8670756473763cd2 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 11:47:16 -0800 Subject: [PATCH 13/30] Add version flag --- command/root.go | 29 +++++++++++++++++------------ command/version.go | 17 +++++++++++++++++ 2 files changed, 34 insertions(+), 12 deletions(-) create mode 100644 command/version.go diff --git a/command/root.go b/command/root.go index f755839c..2833aedb 100644 --- a/command/root.go +++ b/command/root.go @@ -31,9 +31,8 @@ func init() { // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "keyconjurer", - Version: fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), - Short: "Retrieve temporary cloud credentials.", + Use: "keyconjurer", + Short: ".", FParseErrWhitelist: cobra.FParseErrWhitelist{ UnknownFlags: true, }, @@ -46,8 +45,9 @@ type CLI struct { Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` - Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` + ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` + Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` + Version VersionFlag `help:"Show version information." short:"v"` Config Config `kong:"-"` } @@ -103,14 +103,19 @@ func (c *CLI) AfterRun(ctx *kong.Context) error { } func Execute(ctx context.Context, args []string) error { - vars := kong.Vars{ - "client_id": ClientID, - "server_address": ServerAddress, - "oidc_domain": OIDCDomain, - } - var cli CLI - k, err := kong.New(&cli, vars) + k, err := kong.New(&cli, + kong.Name("keyconjurer"), + kong.Description("Retrieve temporary cloud credentials."), + kong.UsageOnError(), + kong.Vars{ + "client_id": ClientID, + "server_address": ServerAddress, + "oidc_domain": OIDCDomain, + "version": fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), + }, + ) + if err != nil { return err } diff --git a/command/version.go b/command/version.go new file mode 100644 index 00000000..99058435 --- /dev/null +++ b/command/version.go @@ -0,0 +1,17 @@ +package command + +import ( + "fmt" + + "github.com/alecthomas/kong" +) + +type VersionFlag string + +func (v VersionFlag) Decode(ctx *kong.DecodeContext) error { return nil } +func (v VersionFlag) IsBool() bool { return true } +func (v VersionFlag) BeforeApply(app *kong.Kong, vars kong.Vars) error { + fmt.Println(vars["version"]) + app.Exit(0) + return nil +} From c18db176302770a430e4dc71417358f0b3379958 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 11:48:36 -0800 Subject: [PATCH 14/30] Remove rootCmd --- command/root.go | 39 ++++++-------------- command/root_test.go | 84 -------------------------------------------- command/version.go | 17 --------- main.go | 3 +- 4 files changed, 11 insertions(+), 132 deletions(-) delete mode 100644 command/root_test.go delete mode 100644 command/version.go diff --git a/command/root.go b/command/root.go index 2833aedb..496a611f 100644 --- a/command/root.go +++ b/command/root.go @@ -1,7 +1,6 @@ package command import ( - "context" "fmt" "os" "path/filepath" @@ -9,7 +8,6 @@ import ( "github.com/alecthomas/kong" "github.com/mitchellh/go-homedir" - "github.com/spf13/cobra" ) var ( @@ -19,35 +17,14 @@ var ( FlagQuiet = "quiet" ) -func init() { - rootCmd.AddCommand(accountsCmd) - rootCmd.AddCommand(setCmd) - rootCmd.AddCommand(&switchCmd) - rootCmd.AddCommand(&aliasCmd) - rootCmd.AddCommand(&unaliasCmd) - rootCmd.AddCommand(&rolesCmd) - rootCmd.SetVersionTemplate("{{.Version}}\n") -} - -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "keyconjurer", - Short: ".", - FParseErrWhitelist: cobra.FParseErrWhitelist{ - UnknownFlags: true, - }, - SilenceErrors: true, - SilenceUsage: true, -} - type CLI struct { Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` - Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` - Version VersionFlag `help:"Show version information." short:"v"` + ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` + Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` + Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` Config Config `kong:"-"` } @@ -102,9 +79,9 @@ func (c *CLI) AfterRun(ctx *kong.Context) error { return c.Config.Write(file) } -func Execute(ctx context.Context, args []string) error { - var cli CLI - k, err := kong.New(&cli, +func newKong(cli *CLI) (*kong.Kong, error) { + return kong.New( + cli, kong.Name("keyconjurer"), kong.Description("Retrieve temporary cloud credentials."), kong.UsageOnError(), @@ -115,7 +92,11 @@ func Execute(ctx context.Context, args []string) error { "version": fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), }, ) +} +func Execute(args []string) error { + var cli CLI + k, err := newKong(&cli) if err != nil { return err } diff --git a/command/root_test.go b/command/root_test.go deleted file mode 100644 index 3864e3d7..00000000 --- a/command/root_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package command - -import ( - "bytes" - "fmt" - "runtime" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" -) - -// Required to reset Cobra state between Test runs. -// If you add new tests that change state, you may need to -// add code here to reset the side effects of other tests -func resetCobra(cmd *cobra.Command) { - cmd.Flags().Set("help", "false") - cmd.Flags().Set("version", "false") -} - -func execute(cmd *cobra.Command, args ...string) (string, error) { - var buf bytes.Buffer - cmd.SetArgs(args) - cmd.SetOutput(&buf) - err := cmd.Execute() - return buf.String(), err -} - -func TestVersionFlag(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "--version") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - expected := fmt.Sprintf("keyconjurer-%s-%s TBD (BuildTimestamp is not set)\n", runtime.GOOS, runtime.GOARCH) - assert.Equal(t, output, expected) -} - -func TestVersionShortFlag(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "-v") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - expected := fmt.Sprintf("keyconjurer-%s-%s TBD (BuildTimestamp is not set)\n", runtime.GOOS, runtime.GOARCH) - assert.Equal(t, output, expected) -} - -func TestHelpLongInvalidArgs(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - output, err := execute(rootCmd, "get", "-s") - if err != nil { - if err.Error() != "unknown shorthand flag: 's' in -s" { - t.Errorf("Unexpected error: %v", err) - } - } else { - t.Errorf("Unexpected non-error, output=: %v", output) - } -} - -func TestInvalidCommand(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "badcommand") - if err != nil { - if err.Error() != "unknown command \"badcommand\" for \"keyconjurer\"" { - t.Errorf("Unexpected error: %v", err) - } - } else { - t.Errorf("Unexpected non-error, output=: %v", output) - } -} diff --git a/command/version.go b/command/version.go deleted file mode 100644 index 99058435..00000000 --- a/command/version.go +++ /dev/null @@ -1,17 +0,0 @@ -package command - -import ( - "fmt" - - "github.com/alecthomas/kong" -) - -type VersionFlag string - -func (v VersionFlag) Decode(ctx *kong.DecodeContext) error { return nil } -func (v VersionFlag) IsBool() bool { return true } -func (v VersionFlag) BeforeApply(app *kong.Kong, vars kong.Vars) error { - fmt.Println(vars["version"]) - app.Exit(0) - return nil -} diff --git a/main.go b/main.go index 2dfa7880..c64d0953 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "errors" "fmt" "os" @@ -43,7 +42,7 @@ func main() { args = append(args, strings.Split(flag, " ")...) } - err := command.Execute(context.Background(), args) + err := command.Execute(args) if IsWindowsPortAccessError(err) { fmt.Fprintf(os.Stderr, "Encountered an issue when opening the port for KeyConjurer: %s\n", err) fmt.Fprintln(os.Stderr, "Consider running `net stop hns` and then `net start hns`") From 25e2bf7752f0deeb7d855405962fdfe480affafe Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 12:40:27 -0800 Subject: [PATCH 15/30] Add account dumping --- command/accounts.go | 79 +++++++++++++++++++++------------------------ command/get.go | 31 ++++++------------ command/login.go | 37 ++++++--------------- command/root.go | 24 ++++++++------ 4 files changed, 72 insertions(+), 99 deletions(-) diff --git a/command/accounts.go b/command/accounts.go index bddc8553..fa38b0a6 100644 --- a/command/accounts.go +++ b/command/accounts.go @@ -8,69 +8,64 @@ import ( "io" "net/http" "net/url" + "os" "github.com/riotgames/key-conjurer/internal/api" - "github.com/spf13/cobra" "golang.org/x/oauth2" ) var ( FlagNoRefresh = "no-refresh" FlagServerAddress = "server-address" - ErrSessionExpired = errors.New("session expired") ) -func init() { - accountsCmd.Flags().Bool(FlagNoRefresh, false, "Indicate that the account list should not be refreshed when executing this command. This is useful if you're not able to reach the account server.") - accountsCmd.Flags().String(FlagServerAddress, ServerAddress, "The address of the account server. This does not usually need to be changed or specified.") -} - type AccountsCommand struct { + Refresh bool `help:"Refresh the list of accounts." default:"true" negatable:""` ServerAddress string `help:"The address of the account server. This does not usually need to be changed or specified." hidden:"" env:"KEYCONJURER_SERVER_ADDRESS" default:"${server_address}"` } -var accountsCmd = &cobra.Command{ - Use: "accounts", - Short: "Prints and optionally refreshes the list of accounts you have access to.", - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - stdOut := cmd.OutOrStdout() - noRefresh, _ := cmd.Flags().GetBool(FlagNoRefresh) - loud := !ShouldUseMachineOutput(cmd.Flags()) - if noRefresh { - config.DumpAccounts(stdOut, loud) - - if loud { - // intentionally uses PrintErrf was a warning - cmd.PrintErrf("--%s was specified - these results may be out of date, and you may not have access to accounts in this list.\n", FlagNoRefresh) - } - - return nil - } +func (a AccountsCommand) Help() string { + return "Prints and optionally refreshes the list of accounts you have access to." +} - serverAddr, _ := cmd.Flags().GetString(FlagServerAddress) - serverAddrURI, err := url.Parse(serverAddr) - if err != nil { - return genericError{ - ExitCode: ExitCodeValueError, - Message: fmt.Sprintf("--%s had an invalid value: %s\n", FlagServerAddress, err), - } - } +func (a AccountsCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { + loud := isPiped() || globals.Quiet + if !a.Refresh { + config.DumpAccounts(os.Stdout, loud) - if HasTokenExpired(config.Tokens) { - return ErrTokensExpiredOrAbsent + if loud { + // intentionally uses Fprintf was a warning + fmt.Fprintf(os.Stderr, "--no-refresh was specified - these results may be out of date, and you may not have access to accounts in this list.\n") } - accounts, err := refreshAccounts(cmd.Context(), serverAddrURI, config.Tokens) - if err != nil { - return fmt.Errorf("error refreshing accounts: %w", err) + return nil + } + + serverAddrURI, err := url.Parse(a.ServerAddress) + if err != nil { + return genericError{ + ExitCode: ExitCodeValueError, + Message: fmt.Sprintf("--%s had an invalid value: %s\n", FlagServerAddress, err), } + } - config.UpdateAccounts(accounts) - config.DumpAccounts(stdOut, loud) - return nil - }, + if HasTokenExpired(config.Tokens) { + return ErrTokensExpiredOrAbsent + } + + accounts, err := refreshAccounts(ctx, serverAddrURI, config.Tokens) + if err != nil { + return fmt.Errorf("error refreshing accounts: %w", err) + } + + config.UpdateAccounts(accounts) + config.DumpAccounts(os.Stdout, loud) + return nil +} + +func (a AccountsCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) } func refreshAccounts(ctx context.Context, serverAddr *url.URL, ts oauth2.TokenSource) ([]Account, error) { diff --git a/command/get.go b/command/get.go index 69b309b3..d0290ebc 100644 --- a/command/get.go +++ b/command/get.go @@ -39,8 +39,6 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` - ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` AccountNameOrID string `arg:""` TimeToLive uint `placeholder:"hours" help:"The key timeout in hours from 1 to 8." default:"1" name:"ttl"` TimeRemaining uint `placeholder:"minutes" help:"Request new keys if there are no keys in the environment or the current keys expire within minutes." default:"5" short:"t"` @@ -55,9 +53,8 @@ type GetCommand struct { Region string `help:"The AWS region to use." env:"AWS_REGION" default:"us-west-2"` BypassCache bool `help:"Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache." hidden:""` - UsageFunc func() error `kong:"-"` - PrintErrln func(...any) `kong:"-"` - MachineOutput bool `kong:"-"` + UsageFunc func() error `kong:"-"` + PrintErrln func(...any) `kong:"-"` } func (g GetCommand) Help() string { @@ -81,22 +78,14 @@ func (g GetCommand) printUsage() error { return g.UsageFunc() } -func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { - // g.MachineOutput = ShouldUseMachineOutput(flags) || g.URLOnly - +func (g GetCommand) RunContext(ctx context.Context, globals *Globals, cfg *Config) error { if HasTokenExpired(cfg.Tokens) { if !g.Login { return ErrTokensExpiredOrAbsent } - loginCommand := LoginCommand{ - OIDCDomain: g.OIDCDomain, - ClientID: g.ClientID, - MachineOutput: g.MachineOutput, - NoBrowser: !g.Browser, - } - - if err := loginCommand.RunContext(ctx, cfg); err != nil { + var loginCommand LoginCommand + if err := loginCommand.RunContext(ctx, globals, cfg); err != nil { return err } } @@ -130,7 +119,7 @@ func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { credentials := LoadAWSCredentialsFromEnvironment() if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) { - newCredentials, err := g.fetchNewCredentials(ctx, *account, cfg) + newCredentials, err := g.fetchNewCredentials(ctx, *account, globals, cfg) if err != nil { return err } @@ -145,12 +134,12 @@ func (g GetCommand) RunContext(ctx context.Context, cfg *Config) error { return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath) } -func (g GetCommand) Run(cfg *Config) error { - return g.RunContext(context.Background(), cfg) +func (g GetCommand) Run(globals *Globals, cfg *Config) error { + return g.RunContext(context.Background(), globals, cfg) } -func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) { - samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) +func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, globals *Globals, cfg *Config) (*CloudCredentials, error) { + samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, globals.OIDCDomain, globals.ClientID, account.ID) if err != nil { return nil, err } diff --git a/command/login.go b/command/login.go index b252a563..51bf96c4 100644 --- a/command/login.go +++ b/command/login.go @@ -13,7 +13,6 @@ import ( "github.com/coreos/go-oidc" "github.com/pkg/browser" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/pflag" ) var ( @@ -21,43 +20,27 @@ var ( FlagNoBrowser = "no-browser" ) -// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine. -// -// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful. -func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { - quiet, _ := flags.GetBool(FlagQuiet) +func isPiped() bool { fi, _ := os.Stdout.Stat() - isPiped := fi.Mode()&os.ModeCharDevice == 0 - return isPiped || quiet + return fi.Mode()&os.ModeCharDevice == 0 } type LoginCommand struct { - OIDCDomain string `help:"The domain name of your OIDC server" hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` - ClientID string `help:"The client ID of your OIDC server" hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` - MachineOutput bool `kong:"-"` - NoBrowser bool `kong:"-"` + URLOnly bool `help:"Print only the URL to visit rather than a user-friendly message." short:"u"` + Browser bool `help:"Open the browser to the Okta URL. If false, a URL will be printed to the command line instead." default:"true" negatable:"" short:"b"` } func (c LoginCommand) Help() string { return "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code." } -func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { - c.OIDCDomain, _ = flags.GetString(FlagOIDCDomain) - c.ClientID, _ = flags.GetString(FlagClientID) - c.NoBrowser, _ = flags.GetBool(FlagNoBrowser) - urlOnly, _ := flags.GetBool(FlagURLOnly) - c.MachineOutput = ShouldUseMachineOutput(flags) || urlOnly - return nil -} - -func (c LoginCommand) RunContext(ctx context.Context, config *Config) error { +func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { if !HasTokenExpired(config.Tokens) { return nil } client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} - oauthCfg, err := oauth2.DiscoverConfig(oidc.ClientContext(ctx, client), c.OIDCDomain, c.ClientID) + oauthCfg, err := oauth2.DiscoverConfig(oidc.ClientContext(ctx, client), globals.OIDCDomain, globals.ClientID) if err != nil { return err } @@ -79,8 +62,8 @@ func (c LoginCommand) RunContext(ctx context.Context, config *Config) error { OnDisplayURL: openBrowserToURL, } - if c.NoBrowser { - if c.MachineOutput { + if !c.Browser { + if isPiped() || globals.Quiet { handler.OnDisplayURL = printURLToConsole } else { handler.OnDisplayURL = friendlyPrintURLToConsole @@ -95,8 +78,8 @@ func (c LoginCommand) RunContext(ctx context.Context, config *Config) error { return config.SaveOAuthToken(accessToken, idToken) } -func (c LoginCommand) Run(config *Config) error { - return c.RunContext(context.Background(), config) +func (c LoginCommand) Run(globals *Globals, config *Config) error { + return c.RunContext(context.Background(), globals, config) } var ErrNoPortsAvailable = errors.New("no ports available") diff --git a/command/root.go b/command/root.go index 496a611f..68a0a095 100644 --- a/command/root.go +++ b/command/root.go @@ -14,19 +14,25 @@ var ( FlagOIDCDomain = "oidc-domain" FlagClientID = "client-id" FlagConfigPath = "config" - FlagQuiet = "quiet" ) +type Globals struct { + OIDCDomain string `help:"The domain name of your OIDC server." hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` + ClientID string `help:"The client ID of your OIDC server." hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` + ConfigPath string `help:"The path to .keyconjurerrc file." default:"~/.keyconjurerrc" name:"config"` + Quiet bool `help:"Tells the CLI to be quiet; stdout will not contain human-readable informational messages."` +} + type CLI struct { - Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` - Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` - // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + Globals - ConfigPath string `help:"path to .keyconjurerrc file" default:"~/.keyconjurerrc" name:"config"` - Quiet bool `help:"tells the CLI to be quiet; stdout will not contain human-readable informational messages"` - Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` + Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` + Accounts AccountsCommand `cmd:"" help:"Display accounts."` + Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` + // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - Config Config `kong:"-"` + Config Config `kong:"-"` + Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` } func (CLI) Help() string { @@ -106,5 +112,5 @@ func Execute(args []string) error { return err } - return kongCtx.Run() + return kongCtx.Run(&cli.Globals) } From d6c61d0882de81c9349033a06adcfb22f9cc66cc Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 12:45:14 -0800 Subject: [PATCH 16/30] Remove Cobra init blocks --- command/accounts.go | 8 +--- command/get.go | 40 +++--------------- command/login.go | 15 +++---- command/root.go | 8 +--- command/switch.go | 90 ++++++++-------------------------------- internal/api/settings.go | 10 ++--- main.go | 32 +++++++------- 7 files changed, 49 insertions(+), 154 deletions(-) diff --git a/command/accounts.go b/command/accounts.go index fa38b0a6..6223600f 100644 --- a/command/accounts.go +++ b/command/accounts.go @@ -14,11 +14,7 @@ import ( "golang.org/x/oauth2" ) -var ( - FlagNoRefresh = "no-refresh" - FlagServerAddress = "server-address" - ErrSessionExpired = errors.New("session expired") -) +var ErrSessionExpired = errors.New("session expired") type AccountsCommand struct { Refresh bool `help:"Refresh the list of accounts." default:"true" negatable:""` @@ -46,7 +42,7 @@ func (a AccountsCommand) RunContext(ctx context.Context, globals *Globals, confi if err != nil { return genericError{ ExitCode: ExitCodeValueError, - Message: fmt.Sprintf("--%s had an invalid value: %s\n", FlagServerAddress, err), + Message: fmt.Sprintf("server-address had an invalid value: %s\n", err), } } diff --git a/command/get.go b/command/get.go index d0290ebc..c0ce6123 100644 --- a/command/get.go +++ b/command/get.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "slices" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -13,24 +12,6 @@ import ( "github.com/riotgames/key-conjurer/oauth2" ) -var ( - FlagRegion = "region" - FlagRoleName = "role" - FlagTimeRemaining = "time-remaining" - FlagTimeToLive = "ttl" - FlagBypassCache = "bypass-cache" - FlagLogin = "login" -) - -var ( - // outputTypeEnvironmentVariable indicates that keyconjurer will dump the credentials to stdout in Bash environment variable format - outputTypeEnvironmentVariable = "env" - // outputTypeAWSCredentialsFile indicates that keyconjurer will dump the credentials into the ~/.aws/credentials file. - outputTypeAWSCredentialsFile = "awscli" - permittedOutputTypes = []string{outputTypeAWSCredentialsFile, outputTypeEnvironmentVariable} - permittedShellTypes = []string{shellTypePowershell, shellTypeBash, shellTypeBasic, shellTypeInfer} -) - func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) { if bypassCache { return &Account{ID: nameOrID, Name: nameOrID}, true @@ -51,7 +32,7 @@ type GetCommand struct { RoleName string `help:"The name of the role to assume." short:"r" name:"role"` SessionName string `help:"The name of the role session name that will show up in CloudTrail logs." default:"KeyConjurer-AssumeRole"` Region string `help:"The AWS region to use." env:"AWS_REGION" default:"us-west-2"` - BypassCache bool `help:"Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache." hidden:""` + Cache bool `help:"Check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache." default:"true" negatable:"" hidden:""` UsageFunc func() error `kong:"-"` PrintErrln func(...any) `kong:"-"` @@ -63,17 +44,6 @@ func (g GetCommand) Help() string { A role must be specified when using this command through the --role flag. You may list the roles you can assume through the roles command, and the accounts through the accounts command.` } -func (g GetCommand) Validate() error { - if !slices.Contains(permittedOutputTypes, g.OutputType) { - return ValueError{Value: g.OutputType, ValidValues: permittedOutputTypes} - } - - if !slices.Contains(permittedShellTypes, g.ShellType) { - return ValueError{Value: g.ShellType, ValidValues: permittedShellTypes} - } - return nil -} - func (g GetCommand) printUsage() error { return g.UsageFunc() } @@ -100,9 +70,9 @@ func (g GetCommand) RunContext(ctx context.Context, globals *Globals, cfg *Confi return g.printUsage() } - account, ok := resolveApplicationInfo(cfg, g.BypassCache, accountID) + account, ok := resolveApplicationInfo(cfg, !g.Cache, accountID) if !ok { - return UnknownAccountError(g.AccountNameOrID, FlagBypassCache) + return UnknownAccountError(g.AccountNameOrID, "--no-cache") } if g.RoleName == "" { @@ -188,10 +158,10 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, gl func echoCredentials(id, name string, credentials CloudCredentials, outputType, shellType, cliPath string) error { switch outputType { - case outputTypeEnvironmentVariable: + case "env": credentials.WriteFormat(os.Stdout, shellType) return nil - case outputTypeAWSCredentialsFile: + case "aws": acc := Account{ID: id, Name: name} newCliEntry := NewCloudCliEntry(credentials, &acc) return SaveCloudCredentialInCLI(cliPath, newCliEntry) diff --git a/command/login.go b/command/login.go index 51bf96c4..8bccf6ec 100644 --- a/command/login.go +++ b/command/login.go @@ -15,16 +15,6 @@ import ( "github.com/riotgames/key-conjurer/oauth2" ) -var ( - FlagURLOnly = "url-only" - FlagNoBrowser = "no-browser" -) - -func isPiped() bool { - fi, _ := os.Stdout.Stat() - return fi.Mode()&os.ModeCharDevice == 0 -} - type LoginCommand struct { URLOnly bool `help:"Print only the URL to visit rather than a user-friendly message." short:"u"` Browser bool `help:"Open the browser to the Okta URL. If false, a URL will be printed to the command line instead." default:"true" negatable:"" short:"b"` @@ -119,3 +109,8 @@ func openBrowserToURL(url string) error { slog.Debug("trying to open browser window", slog.String("url", url)) return browser.OpenURL(url) } + +func isPiped() bool { + fi, _ := os.Stdout.Stat() + return fi.Mode()&os.ModeCharDevice == 0 +} diff --git a/command/root.go b/command/root.go index 68a0a095..492aebf2 100644 --- a/command/root.go +++ b/command/root.go @@ -10,12 +10,6 @@ import ( "github.com/mitchellh/go-homedir" ) -var ( - FlagOIDCDomain = "oidc-domain" - FlagClientID = "client-id" - FlagConfigPath = "config" -) - type Globals struct { OIDCDomain string `help:"The domain name of your OIDC server." hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` ClientID string `help:"The client ID of your OIDC server." hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` @@ -29,7 +23,7 @@ type CLI struct { Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` Accounts AccountsCommand `cmd:"" help:"Display accounts."` Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` - // Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + Switch SwitchCommand `cmd:"" help:"Switch between accounts."` Config Config `kong:"-"` Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` diff --git a/command/switch.go b/command/switch.go index 35c0e08b..4ff0904c 100644 --- a/command/switch.go +++ b/command/switch.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "slices" "strings" "time" @@ -12,86 +11,27 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/spf13/cobra" - "github.com/spf13/pflag" ) -var ( - FlagRoleSessionName = "role-session-name" - FlagOutputType = "out" - FlagShellType = "shell" - FlagAWSCLIPath = "awscli" -) - -func init() { - switchCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs") - switchCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli") - switchCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`") - switchCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws-cli tool. Default is \"~/.aws\".") -} - -var switchCmd = cobra.Command{ - Use: "switch ", - Short: "Switch from the current AWS account into the one with the given Account ID.", - Long: `Attempt to AssumeRole into the given AWS with the current credentials. You only need to use this if you are a power user or network engineer with access to many accounts. - -This is used when a "bastion" account exists which users initially authenticate into and then pivot from that account into other accounts. - -This command will fail if you do not have active Cloud credentials. -`, - Example: "keyconjurer switch 123456798", - Args: cobra.ExactArgs(1), - Aliases: []string{"switch-account"}, - RunE: func(cmd *cobra.Command, args []string) error { - var switchCmd SwitchCommand - if err := switchCmd.Parse(cmd.Flags(), args); err != nil { - return err - } - - if err := switchCmd.Validate(); err != nil { - return err - } - - return switchCmd.Execute(cmd.Context()) - }, -} - type SwitchCommand struct { - OutputType string - ShellType string - AWSCLIPath string - RoleSessionName string - AccountID string + OutputType string `help:"Format to save new credentials in." default:"env" enum:"env,awscli" short:"o" default:"env" name:"out"` + AWSCLIPath string `help:"Path to the AWS CLI configuration directory." default:"~/.aws/" name:"awscli"` + ShellType string `name:"shell" help:"If output type is env, determines which format to output credentials in. WSL users may wish to overwrite this to \"bash\"." default:"infer" enum:"infer,basic,powershell,bash"` + SessionName string `help:"The name of the role session name that will show up in CloudTrail logs." default:"KeyConjurer-AssumeRole"` + AccountID string `arg:"" placeholder:"account-id"` } -func (s *SwitchCommand) Parse(flags *pflag.FlagSet, args []string) error { - s.OutputType, _ = flags.GetString(FlagOutputType) - s.ShellType, _ = flags.GetString(FlagShellType) - s.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath) - s.RoleSessionName, _ = flags.GetString(FlagRoleSessionName) - if len(args) == 0 { - return fmt.Errorf("account-id is required") - } - - s.AccountID = args[0] - return nil -} - -func (s SwitchCommand) Validate() error { - if !slices.Contains(permittedOutputTypes, s.OutputType) { - return ValueError{Value: s.OutputType, ValidValues: permittedOutputTypes} - } +func (SwitchCommand) Help() string { + return `Attempt to AssumeRole into the given AWS with the current credentials. You only need to use this if you are a power user or network engineer with access to many accounts. - if !slices.Contains(permittedShellTypes, s.ShellType) { - return ValueError{Value: s.ShellType, ValidValues: permittedShellTypes} - } +This is used when a "bastion" account exists which users initially authenticate into and then pivot from that account into other accounts. - return nil +This command will fail if you do not have active Cloud credentials.` } -func (s SwitchCommand) Execute(ctx context.Context) error { +func (s SwitchCommand) RunContext(ctx context.Context) error { // We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user. - creds, err := getAWSCredentials(ctx, s.AccountID, s.RoleSessionName) + creds, err := getAWSCredentials(ctx, s.AccountID, s.SessionName) if err != nil { // If this failed, either there was a network error or the user is not authorized to assume into this role // This can happen if the user is not authenticated using the Bastion instance. @@ -99,10 +39,10 @@ func (s SwitchCommand) Execute(ctx context.Context) error { } switch s.OutputType { - case outputTypeEnvironmentVariable: + case "env": creds.WriteFormat(os.Stdout, s.ShellType) return nil - case outputTypeAWSCredentialsFile: + case "aws": acc := Account{ID: s.AccountID, Name: s.AccountID} newCliEntry := NewCloudCliEntry(creds, &acc) return SaveCloudCredentialInCLI(s.AWSCLIPath, newCliEntry) @@ -111,6 +51,10 @@ func (s SwitchCommand) Execute(ctx context.Context) error { } } +func (s SwitchCommand) Run() error { + return s.RunContext(context.Background()) +} + func getAWSCredentials(ctx context.Context, accountID, roleSessionName string) (creds CloudCredentials, err error) { cfg, err := config.LoadDefaultConfig(ctx) if err != nil { diff --git a/internal/api/settings.go b/internal/api/settings.go index c929f8a3..6c7971fd 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -14,14 +14,12 @@ type Settings struct { OktaToken string `json:"oktaToken"` } -var SettingsProviders = map[string]SettingsProvider{} - -func init() { - SettingsProviders["env"] = SettingsProviderFunc(RetrieveSettingsFromEnv) - SettingsProviders["vault"] = VaultRetriever{ +var SettingsProviders = map[string]SettingsProvider{ + "env": SettingsProviderFunc(RetrieveSettingsFromEnv), + "vault": VaultRetriever{ SecretMountPath: os.Getenv("KC_SECRET_MOUNT_PATH"), SecretPath: os.Getenv("KC_SECRET_PATH"), - } + }, } type SettingsProvider interface { diff --git a/main.go b/main.go index c64d0953..0c639ee8 100644 --- a/main.go +++ b/main.go @@ -13,20 +13,7 @@ import ( "github.com/spf13/cobra" ) -const ( - // WSAEACCES is the Windows error code for attempting to access a socket that you don't have permission to access. - // - // This commonly occurs if the socket is in use or was not closed correctly, and can be resolved by restarting the hns service. - WSAEACCES = 10013 -) - -// IsWindowsPortAccessError determines if the given error is the error WSAEACCES. -func IsWindowsPortAccessError(err error) bool { - var syscallErr *syscall.Errno - return errors.As(err, &syscallErr) && *syscallErr == WSAEACCES -} - -func init() { +func main() { var opts slog.HandlerOptions if os.Getenv("DEBUG") == "1" { opts.Level = slog.LevelDebug @@ -34,16 +21,14 @@ func init() { handler := slog.NewTextHandler(os.Stdout, &opts) slog.SetDefault(slog.New(handler)) -} -func main() { args := os.Args[1:] if flag, ok := os.LookupEnv("KEYCONJURERFLAGS"); ok { args = append(args, strings.Split(flag, " ")...) } err := command.Execute(args) - if IsWindowsPortAccessError(err) { + if isWindowsPortAccessError(err) { fmt.Fprintf(os.Stderr, "Encountered an issue when opening the port for KeyConjurer: %s\n", err) fmt.Fprintln(os.Stderr, "Consider running `net stop hns` and then `net start hns`") os.Exit(command.ExitCodeConnectivityError) @@ -59,3 +44,16 @@ func main() { os.Exit(errorCode) } } + +const ( + // wsaeacces is the Windows error code for attempting to access a socket that you don't have permission to access. + // + // This commonly occurs if the socket is in use or was not closed correctly, and can be resolved by restarting the hns service. + wsaeacces = 10013 +) + +// isWindowsPortAccessError determines if the given error is the error wsaeacces. +func isWindowsPortAccessError(err error) bool { + var syscallErr *syscall.Errno + return errors.As(err, &syscallErr) && *syscallErr == wsaeacces +} From ddb4a2613eab6a484b08820e51ea77d1a96e29fe Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 12:52:41 -0800 Subject: [PATCH 17/30] Convert RolesCommand to Kong --- command/roles.go | 59 +++++++++++++++++++++++++++--------------------- command/root.go | 15 ++++++------ 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/command/roles.go b/command/roles.go index 57a16962..0992f641 100644 --- a/command/roles.go +++ b/command/roles.go @@ -1,43 +1,50 @@ package command import ( + "context" + "fmt" "strings" "github.com/RobotsAndPencils/go-saml" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" ) -var rolesCmd = cobra.Command{ - Use: "roles ", - Short: "Returns the roles that you have access to in the given account.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - if HasTokenExpired(config.Tokens) { - return ErrTokensExpiredOrAbsent - } +type RolesCommand struct { + ApplicationID string `arg:""` +} - oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) - clientID, _ := cmd.Flags().GetString(FlagClientID) +func (r RolesCommand) Run(globals *Globals, config *Config) error { + return r.RunContext(context.Background(), globals, config) +} - var applicationID = args[0] - account, ok := config.FindAccount(applicationID) - if ok { - applicationID = account.ID - } +func (r RolesCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { + if HasTokenExpired(config.Tokens) { + return ErrTokensExpiredOrAbsent + } - samlResponse, _, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(cmd.Context(), config.Tokens.AccessToken, config.Tokens.IDToken, oidcDomain, clientID, applicationID) - if err != nil { - return err - } + account, ok := config.FindAccount(r.ApplicationID) + if ok { + r.ApplicationID = account.ID + } - for _, name := range listRoles(samlResponse) { - cmd.Println(name) - } + samlResponse, _, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion( + ctx, + config.Tokens.AccessToken, + config.Tokens.IDToken, + globals.OIDCDomain, + globals.ClientID, + r.ApplicationID, + ) + + if err != nil { + return err + } - return nil - }, + for _, name := range listRoles(samlResponse) { + fmt.Println(name) + } + + return nil } type roleProviderPair struct { diff --git a/command/root.go b/command/root.go index 492aebf2..da2028ef 100644 --- a/command/root.go +++ b/command/root.go @@ -20,13 +20,14 @@ type Globals struct { type CLI struct { Globals - Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` - Accounts AccountsCommand `cmd:"" help:"Display accounts."` - Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` - Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - - Config Config `kong:"-"` - Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` + Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` + Accounts AccountsCommand `cmd:"" help:"Display accounts."` + Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` + Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + Roles RolesCommand `cmd:"" help:"Display roles for a specific account."` + Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` + + Config Config `kong:"-"` } func (CLI) Help() string { From d4576ecc9bce2f5079b14ef673f14c2c66981432 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 16:58:31 -0800 Subject: [PATCH 18/30] Migrate alias --- command/alias.go | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/command/alias.go b/command/alias.go index ad8ced7c..c280396a 100644 --- a/command/alias.go +++ b/command/alias.go @@ -1,16 +1,30 @@ package command -import ( - "github.com/spf13/cobra" -) +import "context" -var aliasCmd = cobra.Command{ - Use: "alias ", - Short: "Give an account a nickname.", - Long: "Alias an account to a nickname so you can refer to the account by the nickname.", - Args: cobra.ExactArgs(2), - Example: "keyconjurer alias FooAccount Bar", - Run: func(cmd *cobra.Command, args []string) { - config := ConfigFromCommand(cmd) - config.Alias(args[0], args[1]) - }} +type AliasCommand struct { + AccountName string `arg:""` + Alias string `arg:""` +} + +func (a AliasCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) +} + +func (a AliasCommand) RunContext(ctx context.Context, _ *Globals, config *Config) error { + config.Alias(a.AccountName, a.Alias) + return nil +} + +type UnaliasCommand struct { + Alias string `arg:""` +} + +func (a UnaliasCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) +} + +func (a UnaliasCommand) RunContext(ctx context.Context, _ *Globals, config *Config) error { + config.Unalias(a.Alias) + return nil +} From 0477eba2bb4286b710af3e9b7dd6d876b443c51a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 17:03:24 -0800 Subject: [PATCH 19/30] Add set commands --- command/root.go | 16 ++++++++----- command/set.go | 64 +++++++++++++++++++------------------------------ 2 files changed, 35 insertions(+), 45 deletions(-) diff --git a/command/root.go b/command/root.go index da2028ef..39953c34 100644 --- a/command/root.go +++ b/command/root.go @@ -20,12 +20,16 @@ type Globals struct { type CLI struct { Globals - Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` - Accounts AccountsCommand `cmd:"" help:"Display accounts."` - Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` - Switch SwitchCommand `cmd:"" help:"Switch between accounts."` - Roles RolesCommand `cmd:"" help:"Display roles for a specific account."` - Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` + Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` + Accounts AccountsCommand `cmd:"" help:"Display accounts."` + Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` + Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + Roles RolesCommand `cmd:"" help:"Display roles for a specific account."` + Alias AliasCommand `cmd:"" help:"Create an alias for an account."` + Unalias UnaliasCommand `cmd:"" help:"Remove an alias."` + Set SetCommand `cmd:"" help:"Set config values."` + + Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` Config Config `kong:"-"` } diff --git a/command/set.go b/command/set.go index d071b0c4..a40b00ed 100644 --- a/command/set.go +++ b/command/set.go @@ -3,51 +3,37 @@ package command import ( "fmt" "strconv" - - "github.com/spf13/cobra" ) -func init() { - setCmd.AddCommand(setTTLCmd) - setCmd.AddCommand(setTimeRemainingCmd) +type SetCommand struct { + TTL TTLCommand `name:"ttl" help:"Sets ttl value in number of hours."` + TimeRemaining TimeRemainingCommand `name:"time-remaining" help:"Sets time remaining value in number of minutes."` } -var setCmd = &cobra.Command{ - Use: "set", - Short: "Sets config values.", - Long: "Sets config values.", +type TTLCommand struct { + TTL string `arg:"" help:"The ttl value in number of hours." placeholder:"hours"` } -var setTTLCmd = &cobra.Command{ - Use: "ttl ", - Short: "Sets ttl value in number of hours.", - Long: "Sets ttl value in number of hours.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - ttl, err := strconv.ParseUint(args[0], 10, 32) - if err != nil { - return fmt.Errorf("unable to parse value %s", args[0]) - } - - config.TTL = uint(ttl) - return nil - }, +func (t TTLCommand) Run(globals *Globals, config *Config) error { + ttl, err := strconv.ParseUint(t.TTL, 10, 32) + if err != nil { + return fmt.Errorf("unable to parse value %s", t.TTL) + } + + config.TTL = uint(ttl) + return nil } -var setTimeRemainingCmd = &cobra.Command{ - Use: "time-remaining ", - Short: "Sets time remaining value in number of minutes.", - Long: "Sets time remaining value in number of minutes. Using minutes is an artifact from when keys could only live for 1 hour.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - timeRemaining, err := strconv.ParseUint(args[0], 10, 32) - if err != nil { - return fmt.Errorf("unable to parse value %s", args[0]) - } - - config.TimeRemaining = uint(timeRemaining) - return nil - }, +type TimeRemainingCommand struct { + TimeRemaining string `arg:"" help:"The time remaining value in number of minutes." placeholder:"minutes"` +} + +func (t TimeRemainingCommand) Run(globals *Globals, config *Config) error { + timeRemaining, err := strconv.ParseUint(t.TimeRemaining, 10, 32) + if err != nil { + return fmt.Errorf("unable to parse value %s", t.TimeRemaining) + } + + config.TimeRemaining = uint(timeRemaining) + return nil } From 6b59dbd9f8908158b9b8abeafbe8f54f779fbc66 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 12 Nov 2024 17:03:33 -0800 Subject: [PATCH 20/30] Remove Cobra --- command/context.go | 26 -------------------------- command/unalias.go | 15 --------------- go.mod | 5 +---- go.sum | 14 ++++++-------- main.go | 4 +--- 5 files changed, 8 insertions(+), 56 deletions(-) delete mode 100644 command/context.go delete mode 100644 command/unalias.go diff --git a/command/context.go b/command/context.go deleted file mode 100644 index 2329c49c..00000000 --- a/command/context.go +++ /dev/null @@ -1,26 +0,0 @@ -package command - -import ( - "context" - - "github.com/spf13/cobra" -) - -type configInfo struct { - Path string - Config *Config -} - -type ctxKeyConfig struct{} - -func ConfigFromCommand(cmd *cobra.Command) *Config { - return cmd.Context().Value(ctxKeyConfig{}).(*configInfo).Config -} - -func ConfigPathFromCommand(cmd *cobra.Command) string { - return cmd.Context().Value(ctxKeyConfig{}).(*configInfo).Path -} - -func ConfigContext(ctx context.Context, config *Config, path string) context.Context { - return context.WithValue(ctx, ctxKeyConfig{}, &configInfo{Path: path, Config: config}) -} diff --git a/command/unalias.go b/command/unalias.go deleted file mode 100644 index aef1ea05..00000000 --- a/command/unalias.go +++ /dev/null @@ -1,15 +0,0 @@ -package command - -import ( - "github.com/spf13/cobra" -) - -var unaliasCmd = cobra.Command{ - Use: "unalias ", - Short: "Remove alias from account.", - Args: cobra.ExactArgs(1), - Example: "keyconjurer unalias bar", - Run: func(cmd *cobra.Command, args []string) { - config := ConfigFromCommand(cmd) - config.Unalias(args[0]) - }} diff --git a/go.mod b/go.mod index d8cb376a..401fd4dc 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/riotgames/key-conjurer require ( github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b + github.com/alecthomas/kong v1.4.0 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.32.4 github.com/aws/aws-sdk-go-v2/config v1.28.3 @@ -14,15 +15,12 @@ require ( github.com/mitchellh/go-ps v1.0.0 github.com/okta/okta-sdk-golang/v2 v2.2.1 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 - github.com/spf13/cobra v1.7.0 - github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 golang.org/x/net v0.25.0 golang.org/x/oauth2 v0.24.0 ) require ( - github.com/alecthomas/kong v1.4.0 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.44 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.19 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.23 // indirect @@ -45,7 +43,6 @@ require ( github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kr/text v0.2.0 // indirect diff --git a/go.sum b/go.sum index 20c27d1c..58ef70e4 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,11 @@ github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b h1:EgJ6N2S0h1WfFIjU5/VVHWbMSVYXAluop97Qxpr/lfQ= github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b/go.mod h1:3SAoF0F5EbcOuBD5WT9nYkbIJieBS84cUQXADbXeBsU= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= github.com/alecthomas/kong v1.4.0 h1:UL7tzGMnnY0YRMMvJyITIRX1EpO6RbBRZDNcCevy3HA= github.com/alecthomas/kong v1.4.0/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI= github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= @@ -37,7 +41,6 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk= github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -86,8 +89,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/vault/api v1.15.0 h1:O24FYQCWwhwKnF7CuSqP30S51rTV7vz1iACXE/pj5DA= github.com/hashicorp/vault/api v1.15.0/go.mod h1:+5YTO09JGn0u+b6ySD/LLVf8WkJCPLAL2Vkmrn2+CM8= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/jarcoal/httpmock v1.0.6/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= @@ -130,7 +133,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= @@ -138,10 +140,6 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a h1:pa8hGb/2YqsZKovtsgrwcDH1RZhVbTKCjLp47XpqCDs= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/main.go b/main.go index 0c639ee8..34fd985d 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "log/slog" "github.com/riotgames/key-conjurer/command" - "github.com/spf13/cobra" ) func main() { @@ -35,8 +34,7 @@ func main() { } if err != nil { - cobra.CheckErr(err) - + fmt.Fprintln(os.Stderr, "Error:", err) errorCode, ok := command.GetExitCode(err) if !ok { errorCode = command.ExitCodeUnknownError From e154a574d9b5fcfbee33ce44220d0a3e4d07a089 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Thu, 14 Nov 2024 12:46:49 -0800 Subject: [PATCH 21/30] Add TTL and Time Remaining cmd tags --- command/set.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/command/set.go b/command/set.go index a40b00ed..f877b072 100644 --- a/command/set.go +++ b/command/set.go @@ -6,8 +6,8 @@ import ( ) type SetCommand struct { - TTL TTLCommand `name:"ttl" help:"Sets ttl value in number of hours."` - TimeRemaining TimeRemainingCommand `name:"time-remaining" help:"Sets time remaining value in number of minutes."` + TTL TTLCommand `cmd:"" name:"ttl" help:"Sets ttl value in number of hours."` + TimeRemaining TimeRemainingCommand `cmd:"" name:"time-remaining" help:"Sets time remaining value in number of minutes."` } type TTLCommand struct { From 3be430187699f8404676ff556c995f2d9bf7c213 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 3 Dec 2024 18:12:23 -0800 Subject: [PATCH 22/30] Generate state and verifier inside of the handler --- command/login.go | 15 +++++++-------- oauth2/oauth2.go | 34 +++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/command/login.go b/command/login.go index 8bccf6ec..8febe50f 100644 --- a/command/login.go +++ b/command/login.go @@ -47,20 +47,19 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := oauth2.RedirectionFlowHandler{ - Config: oauthCfg, - OnDisplayURL: openBrowserToURL, - } - + handler := oauth2.AuthorizationCodeHandler{Config: oauthCfg} + session := handler.NewSession() if !c.Browser { if isPiped() || globals.Quiet { - handler.OnDisplayURL = printURLToConsole + printURLToConsole(session.URL()) } else { - handler.OnDisplayURL = friendlyPrintURLToConsole + friendlyPrintURLToConsole(session.URL()) } + } else { + browser.OpenURL(session.URL()) } - accessToken, idToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GenerateState()) + accessToken, idToken, err := handler.WaitForToken(ctx, sock, session) if err != nil { return err } diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 57a55330..9ef263ad 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -127,32 +127,40 @@ func (e OAuth2Error) Error() string { return fmt.Sprintf("oauth2 error: %s (%s)", e.Description, e.Reason) } -func GenerateState() string { +func generateState() string { stateBuf := make([]byte, stateBufSize) rand.Read(stateBuf) return base64.URLEncoding.EncodeToString(stateBuf) } -type RedirectionFlowHandler struct { - Config *oauth2.Config - OnDisplayURL func(url string) error +type Session struct { + url string + state string + verifier string } -func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, state string) (*oauth2.Token, string, error) { - if r.OnDisplayURL == nil { - panic("OnDisplayURL must be set") - } +func (s Session) URL() string { + return s.url +} + +type AuthorizationCodeHandler struct { + Config *oauth2.Config + sessions map[string]Session +} +func (r *AuthorizationCodeHandler) NewSession() Session { + state := generateState() verifier := oauth2.GenerateVerifier() url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + s := Session{verifier: verifier, state: state, url: url} + r.sessions[state] = s + return s +} +func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) { ch := make(chan Callback, 1) // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(listener, OAuth2CallbackHandler(r.Config, state, verifier, ch)) - - if err := r.OnDisplayURL(url); err != nil { - return nil, "", fmt.Errorf("failed to display link: %w", err) - } + go http.Serve(listener, OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch)) select { case info := <-ch: From 384ffc95a053a1adbdb7fd06ef0727117b322d87 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 3 Dec 2024 19:24:59 -0800 Subject: [PATCH 23/30] Ensure the server is closed after receiving a request --- oauth2/oauth2.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 9ef263ad..122d69cd 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -86,6 +86,7 @@ type CodeExchanger interface { // The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. // We will simply ignore any other requests and only serve the first. var info OAuth2CallbackState @@ -105,14 +106,16 @@ func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan return } + // Make sure to respond to the user right away. If we don't, + // the server may be closed before a response can be sent. + fmt.Fprintln(w, "You may close this window now.") + // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse if idToken, ok := token.Extra("id_token").(string); ok { ch <- Callback{Token: token, IDToken: &idToken} } else { ch <- Callback{Token: token} } - - fmt.Fprintln(w, "You may close this window now.") } return http.HandlerFunc(fn) @@ -159,18 +162,22 @@ func (r *AuthorizationCodeHandler) NewSession() Session { func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) { ch := make(chan Callback, 1) - // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(listener, OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch)) + server := http.Server{ + Handler: OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch), + } + + go server.Serve(listener) select { case info := <-ch: - // TODO: Close the server immediately to prevent any more requests being received. + server.Close() if info.Error != nil { return nil, "", info.Error } return info.Token, "", nil case <-ctx.Done(): + server.Close() return nil, "", ctx.Err() } } From c5dad369588bf73731bf0b161c327530152bd248 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 04:26:40 -0800 Subject: [PATCH 24/30] Use channels to clean up AuthorizationCodeHandler This change makes AuthorizationCodeHandler reentrant and prevents panics that could occur if two requests occurred before the server was closed. It also makes the API nicer. --- command/login.go | 25 +++++++-- oauth2/oauth2.go | 131 +++++++++++++++-------------------------------- 2 files changed, 60 insertions(+), 96 deletions(-) diff --git a/command/login.go b/command/login.go index 8febe50f..12d7d02f 100644 --- a/command/login.go +++ b/command/login.go @@ -47,7 +47,7 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := oauth2.AuthorizationCodeHandler{Config: oauthCfg} + handler := &oauth2.AuthorizationCodeHandler{Config: oauthCfg} session := handler.NewSession() if !c.Browser { if isPiped() || globals.Quiet { @@ -59,12 +59,27 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * browser.OpenURL(session.URL()) } - accessToken, idToken, err := handler.WaitForToken(ctx, sock, session) - if err != nil { + errCh := make(chan error, 1) + go func() { + err := http.Serve(sock, handler) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: return err + case err := <-session.Error: + return err + case token := <-session.Token: + // TODO Will panic if id_token not present + // TODO Verify token with OIDC provider + idToken := token.Extra("id_token").(string) + return config.SaveOAuthToken(token, idToken) } - - return config.SaveOAuthToken(accessToken, idToken) } func (c LoginCommand) Run(globals *Globals, config *Config) error { diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 122d69cd..17879d67 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -6,9 +6,8 @@ import ( "encoding/base64" "errors" "fmt" - "net" "net/http" - "strings" + "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -54,73 +53,6 @@ func (o *OAuth2CallbackState) FromRequest(r *http.Request) { o.code = r.FormValue("code") } -// Verify safely compares the given state with the state from the OAuth2 callback. -// -// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned. -func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { - if o.errorMessage != "" { - return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription} - } - - if strings.Compare(o.state, expectedState) != 0 { - return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} - } - - return o.code, nil -} - -type Callback struct { - Token *oauth2.Token - IDToken *string - Error error -} - -type CodeExchanger interface { - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) -} - -// OAuth2CallbackHandler returns a http.Handler, channel and function triple. -// -// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. -// -// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. -func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - - // This can sometimes be called multiple times, depending on the browser. - // We will simply ignore any other requests and only serve the first. - var info OAuth2CallbackState - info.FromRequest(r) - - code, err := info.Verify(state) - if err != nil { - ch <- Callback{Error: err} - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - token, err := codeEx.Exchange(r.Context(), code, oauth2.VerifierOption(verifier)) - if err != nil { - ch <- Callback{Error: err} - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Make sure to respond to the user right away. If we don't, - // the server may be closed before a response can be sent. - fmt.Fprintln(w, "You may close this window now.") - - // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - if idToken, ok := token.Extra("id_token").(string); ok { - ch <- Callback{Token: token, IDToken: &idToken} - } else { - ch <- Callback{Token: token} - } - } - - return http.HandlerFunc(fn) -} - type OAuth2Error struct { Reason string Description string @@ -140,6 +72,9 @@ type Session struct { url string state string verifier string + + Token chan *oauth2.Token + Error chan error } func (s Session) URL() string { @@ -147,39 +82,53 @@ func (s Session) URL() string { } type AuthorizationCodeHandler struct { - Config *oauth2.Config + Config *oauth2.Config + sessions map[string]Session + mu sync.Mutex } -func (r *AuthorizationCodeHandler) NewSession() Session { +func (h *AuthorizationCodeHandler) NewSession() Session { state := generateState() verifier := oauth2.GenerateVerifier() - url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) - s := Session{verifier: verifier, state: state, url: url} - r.sessions[state] = s + url := h.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token)} + h.mu.Lock() + defer h.mu.Unlock() + h.sessions[state] = s return s } -func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) { - ch := make(chan Callback, 1) - server := http.Server{ - Handler: OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch), - } - - go server.Serve(listener) +func (h *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var info OAuth2CallbackState + info.FromRequest(r) - select { - case info := <-ch: - server.Close() - if info.Error != nil { - return nil, "", info.Error - } + // This lock is manually released in both branches, because if we defer() it, then it will get released + // after the Exchange() call. Exchange() can take a decent amount of time since it involves a remote call, + // and we don't want to hold the mutex lock for that long. + h.mu.Lock() + session, ok := h.sessions[info.state] + if !ok { + h.mu.Unlock() + http.Error(w, "no session", http.StatusBadRequest) + return + } + // Delete the session early so we can release the lock. + delete(h.sessions, info.state) + h.mu.Unlock() - return info.Token, "", nil - case <-ctx.Done(): - server.Close() - return nil, "", ctx.Err() + token, err := h.Config.Exchange(r.Context(), info.code, oauth2.VerifierOption(session.verifier)) + if err != nil { + session.Error <- err + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + + // Make sure to respond to the user right away. If we don't, + // the server may be closed before a response can be sent. + fmt.Fprintln(w, "You may close this window now.") + session.Token <- token + close(session.Token) } func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { From 30cbd3c154963325939ac7e7bd27b8e8ef1a272c Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 04:42:35 -0800 Subject: [PATCH 25/30] Reduce public API surface --- command/login.go | 2 +- oauth2/oauth2.go | 108 ++++++++++++++++++++---------------------- oauth2/oauth2_test.go | 45 +++++++++++++----- 3 files changed, 84 insertions(+), 71 deletions(-) diff --git a/command/login.go b/command/login.go index 12d7d02f..8e321b8c 100644 --- a/command/login.go +++ b/command/login.go @@ -47,7 +47,7 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := &oauth2.AuthorizationCodeHandler{Config: oauthCfg} + handler := oauth2.NewAuthorizationCodeHandler(oauthCfg) session := handler.NewSession() if !c.Browser { if isPiped() || globals.Quiet { diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 17879d67..29076c72 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -35,31 +35,20 @@ func DiscoverConfig(ctx context.Context, domain, clientID string) (*oauth2.Confi return &cfg, nil } -// OAuth2CallbackState encapsulates all of the information from an oauth2 callback. -// -// To retrieve the Code from the struct, you must use the Verify(string) function. -type OAuth2CallbackState struct { +type callbackState struct { code string state string errorMessage string errorDescription string } -// FromRequest parses the given http.Request and populates the OAuth2CallbackState with those values. -func (o *OAuth2CallbackState) FromRequest(r *http.Request) { - o.errorMessage = r.FormValue("error") - o.errorDescription = r.FormValue("error_description") - o.state = r.FormValue("state") - o.code = r.FormValue("code") -} - -type OAuth2Error struct { - Reason string - Description string -} - -func (e OAuth2Error) Error() string { - return fmt.Sprintf("oauth2 error: %s (%s)", e.Description, e.Reason) +func parseOAuth2CallbackState(r *http.Request, info *callbackState) error { + err := r.ParseForm() + info.errorMessage = r.FormValue("error") + info.errorDescription = r.FormValue("error_description") + info.state = r.FormValue("state") + info.code = r.FormValue("code") + return err } func generateState() string { @@ -81,43 +70,61 @@ func (s Session) URL() string { return s.url } -type AuthorizationCodeHandler struct { - Config *oauth2.Config +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. +func NewAuthorizationCodeHandler(config *oauth2.Config) *AuthorizationCodeHandler { + return &AuthorizationCodeHandler{ + config: config, + sessions: make(map[string]Session), + } +} +// AuthorizationCodeHandler is an http.Handler that handles the OAuth2 authorization code flow. +// +// It is intended to be used by CLIs that need to authenticate with an OAuth2 provider. +// +// Sessions can be created using NewSession, and those sessions can be used to retrieve the OAuth2 token. +type AuthorizationCodeHandler struct { + config *oauth2.Config sessions map[string]Session mu sync.Mutex } -func (h *AuthorizationCodeHandler) NewSession() Session { +func (a *AuthorizationCodeHandler) NewSession() Session { state := generateState() verifier := oauth2.GenerateVerifier() - url := h.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) - s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token)} - h.mu.Lock() - defer h.mu.Unlock() - h.sessions[state] = s + url := a.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + // A channel capacity of 1 is used to prevent requests from blocking if they are not actively being awaited on. + s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token, 1)} + a.mu.Lock() + defer a.mu.Unlock() + a.sessions[state] = s return s } -func (h *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var info OAuth2CallbackState - info.FromRequest(r) +func (a *AuthorizationCodeHandler) removeSessionIfExists(state string) (Session, bool) { + a.mu.Lock() + defer a.mu.Unlock() + s, ok := a.sessions[state] + if ok { + delete(a.sessions, state) + } + return s, ok +} + +func (a *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var st callbackState + if err := parseOAuth2CallbackState(r, &st); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } - // This lock is manually released in both branches, because if we defer() it, then it will get released - // after the Exchange() call. Exchange() can take a decent amount of time since it involves a remote call, - // and we don't want to hold the mutex lock for that long. - h.mu.Lock() - session, ok := h.sessions[info.state] + session, ok := a.removeSessionIfExists(st.state) if !ok { - h.mu.Unlock() http.Error(w, "no session", http.StatusBadRequest) return } - // Delete the session early so we can release the lock. - delete(h.sessions, info.state) - h.mu.Unlock() - token, err := h.Config.Exchange(r.Context(), info.code, oauth2.VerifierOption(session.verifier)) + token, err := a.config.Exchange(r.Context(), st.code, oauth2.VerifierOption(session.verifier)) if err != nil { session.Error <- err http.Error(w, err.Error(), http.StatusInternalServerError) @@ -134,36 +141,23 @@ func (h *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { oauthCfg, err := DiscoverConfig(ctx, oidcDomain, clientID) if err != nil { - return nil, "", Error{Message: "could not discover oauth2 config", InnerError: err} + return nil, "", fmt.Errorf("could not discover oauth2 config: %w", err) } tok, err := exchangeAccessTokenForWebSSOToken(ctx, oauthCfg, accessToken, idToken, applicationID) if err != nil { - return nil, "", Error{Message: "error exchanging token", InnerError: err} + return nil, "", fmt.Errorf("error exchanging token: %w", err) } assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) if err != nil { - return nil, "", Error{Message: "failed to fetch SAML assertion", InnerError: err} + return nil, "", fmt.Errorf("failed to fetch SAML assertion: %w", err) } response, err := saml.ParseEncodedResponse(string(assertionBytes)) if err != nil { - return nil, "", Error{Message: "failed to parse SAML response", InnerError: err} + return nil, "", fmt.Errorf("failed to parse SAML response: %w", err) } return response, string(assertionBytes), nil } - -type Error struct { - InnerError error - Message string -} - -func (o Error) Unwrap() error { - return o.InnerError -} - -func (o Error) Error() string { - return o.Message -} diff --git a/oauth2/oauth2_test.go b/oauth2/oauth2_test.go index 400007cc..ea668618 100644 --- a/oauth2/oauth2_test.go +++ b/oauth2/oauth2_test.go @@ -1,7 +1,6 @@ package oauth2 import ( - "context" "net/http" "net/http/httptest" "net/url" @@ -10,6 +9,34 @@ import ( "golang.org/x/oauth2" ) +var cfg = &oauth2.Config{ + ClientID: "client-id", + Endpoint: oauth2.Endpoint{ + TokenURL: "http://localhost/oauth2/token", + }, +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +var client = http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/oauth2/token" { + return &http.Response{ + StatusCode: http.StatusOK, + Body: nil, + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: nil, + }, nil + }), +} + func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { uri := url.URL{ Scheme: "http", @@ -23,18 +50,9 @@ func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { handler.ServeHTTP(w, req) } -type testCodeExchanger struct{} - -func (t *testCodeExchanger) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return nil, nil -} - -// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel -func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { - ch := make(chan Callback, 2) - defer close(ch) - - handler := OAuth2CallbackHandler(&testCodeExchanger{}, "state", "verifier", ch) +// Test_AuthorizationCodeHandler_IsReentrant prevents an issue where AuthorizationCodeHandler would send a request to a closed channel +func Test_AuthorizationCodeHandler_IsReentrant(t *testing.T) { + handler := NewAuthorizationCodeHandler(cfg) go sendOAuth2CallbackRequest(handler, url.Values{ // We send empty values because we don't care about processing in this test @@ -47,4 +65,5 @@ func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { "code": []string{"not the expected code and should be discarded"}, "state": []string{"not the expected state and should be discarded"}, }) + // If we reach here with no panics, it should pass } From 18a4424974c6fcb81c32467ea31209c48d53341a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 05:00:25 -0800 Subject: [PATCH 26/30] Add a dedicated Okta package --- command/get.go | 8 +++++++- command/roles.go | 10 ++++++++-- oauth2/oauth2.go | 28 ---------------------------- {oauth2 => okta}/html.go | 2 +- {oauth2 => okta}/html_test.go | 2 +- {oauth2 => okta}/websso.go | 24 +++++++++++++++++++++++- 6 files changed, 40 insertions(+), 34 deletions(-) rename {oauth2 => okta}/html.go (99%) rename {oauth2 => okta}/html_test.go (99%) rename {oauth2 => okta}/websso.go (73%) diff --git a/command/get.go b/command/get.go index c0ce6123..99f7373c 100644 --- a/command/get.go +++ b/command/get.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/riotgames/key-conjurer/oauth2" + "github.com/riotgames/key-conjurer/okta" ) func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) { @@ -109,7 +110,12 @@ func (g GetCommand) Run(globals *Globals, cfg *Config) error { } func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, globals *Globals, cfg *Config) (*CloudCredentials, error) { - samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, globals.OIDCDomain, globals.ClientID, account.ID) + oauth2Cfg, err := oauth2.DiscoverConfig(ctx, globals.OIDCDomain, globals.ClientID) + if err != nil { + return nil, err + } + + samlResponse, assertionStr, err := okta.ExchangeTokenForAssertion(ctx, oauth2Cfg, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, globals.OIDCDomain, account.ID) if err != nil { return nil, err } diff --git a/command/roles.go b/command/roles.go index 0992f641..237df219 100644 --- a/command/roles.go +++ b/command/roles.go @@ -7,6 +7,7 @@ import ( "github.com/RobotsAndPencils/go-saml" "github.com/riotgames/key-conjurer/oauth2" + "github.com/riotgames/key-conjurer/okta" ) type RolesCommand struct { @@ -27,12 +28,17 @@ func (r RolesCommand) RunContext(ctx context.Context, globals *Globals, config * r.ApplicationID = account.ID } - samlResponse, _, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion( + cfg, err := oauth2.DiscoverConfig(ctx, globals.OIDCDomain, globals.ClientID) + if err != nil { + return err + } + + samlResponse, _, err := okta.ExchangeTokenForAssertion( ctx, + cfg, config.Tokens.AccessToken, config.Tokens.IDToken, globals.OIDCDomain, - globals.ClientID, r.ApplicationID, ) diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 29076c72..3c951ef6 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -4,18 +4,14 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" "fmt" "net/http" "sync" - "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" "golang.org/x/oauth2" ) -var ErrNoSAMLAssertion = errors.New("no saml assertion") - // stateBufSize is the size of the buffer used to generate the state parameter. // 43 is a magic number - It generates states that are not too short or long for Okta's validation. const stateBufSize = 43 @@ -137,27 +133,3 @@ func (a *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ session.Token <- token close(session.Token) } - -func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { - oauthCfg, err := DiscoverConfig(ctx, oidcDomain, clientID) - if err != nil { - return nil, "", fmt.Errorf("could not discover oauth2 config: %w", err) - } - - tok, err := exchangeAccessTokenForWebSSOToken(ctx, oauthCfg, accessToken, idToken, applicationID) - if err != nil { - return nil, "", fmt.Errorf("error exchanging token: %w", err) - } - - assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) - if err != nil { - return nil, "", fmt.Errorf("failed to fetch SAML assertion: %w", err) - } - - response, err := saml.ParseEncodedResponse(string(assertionBytes)) - if err != nil { - return nil, "", fmt.Errorf("failed to parse SAML response: %w", err) - } - - return response, string(assertionBytes), nil -} diff --git a/oauth2/html.go b/okta/html.go similarity index 99% rename from oauth2/html.go rename to okta/html.go index d01681e6..7d94df1b 100644 --- a/oauth2/html.go +++ b/okta/html.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "errors" diff --git a/oauth2/html_test.go b/okta/html_test.go similarity index 99% rename from oauth2/html_test.go rename to okta/html_test.go index 391fd65c..961a3b57 100644 --- a/oauth2/html_test.go +++ b/okta/html_test.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "strings" diff --git a/oauth2/websso.go b/okta/websso.go similarity index 73% rename from oauth2/websso.go rename to okta/websso.go index 379a0c10..1ae9ab28 100644 --- a/oauth2/websso.go +++ b/okta/websso.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "context" @@ -7,10 +7,13 @@ import ( "net/http" "net/url" + "github.com/RobotsAndPencils/go-saml" "golang.org/x/net/html" "golang.org/x/oauth2" ) +var ErrNoSAMLAssertion = errors.New("no saml assertion") + // exchangeAccessTokenForWebSSOToken exchanges an OAuth2 token for an Okta Web SSO token. // // An Okta Web SSO token is a non-standard authorization token for Okta's Web SSO endpoint. @@ -63,3 +66,22 @@ func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, tok return []byte(saml), nil } + +func ExchangeTokenForAssertion(ctx context.Context, cfg *oauth2.Config, accessToken, idToken, oidcDomain, applicationID string) (*saml.Response, string, error) { + tok, err := exchangeAccessTokenForWebSSOToken(ctx, cfg, accessToken, idToken, applicationID) + if err != nil { + return nil, "", fmt.Errorf("error exchanging token: %w", err) + } + + assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) + if err != nil { + return nil, "", fmt.Errorf("failed to fetch SAML assertion: %w", err) + } + + response, err := saml.ParseEncodedResponse(string(assertionBytes)) + if err != nil { + return nil, "", fmt.Errorf("failed to parse SAML response: %w", err) + } + + return response, string(assertionBytes), nil +} From 50f9d4d13df7cdea9ef8f67d4e7ef555dc2bc8a3 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 05:07:12 -0800 Subject: [PATCH 27/30] Don't access oauth2 context directly --- okta/websso.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/okta/websso.go b/okta/websso.go index 1ae9ab28..2dfd3970 100644 --- a/okta/websso.go +++ b/okta/websso.go @@ -33,17 +33,12 @@ func exchangeAccessTokenForWebSSOToken(ctx context.Context, oauthCfg *oauth2.Con // exchangeWebSSOTokenForSAMLAssertion is an Okta-specific API which exchanges an Okta Web SSO token, which is obtained by exchanging an OAuth2 token using the RFC8693 Token Exchange Flow, for a SAML assertion. // // It is not standards compliant, but is used by Okta in their own okta-aws-cli. -func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, token *oauth2.Token) ([]byte, error) { +func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, client *http.Client, issuer string, token *oauth2.Token) ([]byte, error) { data := url.Values{"token": {token.AccessToken}} uri := fmt.Sprintf("%s/login/token/sso?%s", issuer, data.Encode()) req, _ := http.NewRequestWithContext(ctx, "GET", uri, nil) req.Header.Add("Accept", "text/html") - client := http.DefaultClient - if val, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { - client = val - } - resp, err := client.Do(req) if err != nil { return nil, err @@ -73,7 +68,7 @@ func ExchangeTokenForAssertion(ctx context.Context, cfg *oauth2.Config, accessTo return nil, "", fmt.Errorf("error exchanging token: %w", err) } - assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) + assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, cfg.Client(ctx, nil), oidcDomain, tok) if err != nil { return nil, "", fmt.Errorf("failed to fetch SAML assertion: %w", err) } From 005962381a44212e0a06582f8af6fc9724e5a252 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 06:12:36 -0800 Subject: [PATCH 28/30] Use os.UserConfigDir() for KeyConjurer configuration --- command/root.go | 29 +++++++++++++++++++---------- go.mod | 3 --- go.sum | 16 ---------------- 3 files changed, 19 insertions(+), 29 deletions(-) diff --git a/command/root.go b/command/root.go index 39953c34..f06c6de0 100644 --- a/command/root.go +++ b/command/root.go @@ -7,13 +7,11 @@ import ( "runtime" "github.com/alecthomas/kong" - "github.com/mitchellh/go-homedir" ) type Globals struct { OIDCDomain string `help:"The domain name of your OIDC server." hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` ClientID string `help:"The client ID of your OIDC server." hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` - ConfigPath string `help:"The path to .keyconjurerrc file." default:"~/.keyconjurerrc" name:"config"` Quiet bool `help:"Tells the CLI to be quiet; stdout will not contain human-readable informational messages."` } @@ -43,12 +41,22 @@ To get started run the following commands: keyconjurer get ` } +func getConfigPath() (string, error) { + cfgDir, err := os.UserConfigDir() + if err != nil { + return "", err + } + + return filepath.Join(cfgDir, "keyconjurer", "config.json"), nil +} + func (c *CLI) BeforeApply(ctx *kong.Context, trace *kong.Path) error { - if expanded, err := homedir.Expand(c.ConfigPath); err == nil { - c.ConfigPath = expanded + cfgPath, err := getConfigPath() + if err != nil { + return err } - file, err := EnsureConfigFileExists(c.ConfigPath) + file, err := EnsureConfigFileExists(cfgPath) if err != nil { return err } @@ -65,19 +73,20 @@ func (c *CLI) BeforeApply(ctx *kong.Context, trace *kong.Path) error { } func (c *CLI) AfterRun(ctx *kong.Context) error { - if expanded, err := homedir.Expand(c.ConfigPath); err == nil { - c.ConfigPath = expanded + cfgPath, err := getConfigPath() + if err != nil { + return err } // Do not use EnsureConfigFileExists here! EnsureConfigFileExists opens the file in append mode. // If we open the file in append mode, we'll always append to the file. If we open the file in truncate mode before reading from the file, the content will be truncated _before we read from it_, which will cause a users configuration to be discarded every time we run the program. - if err := os.MkdirAll(filepath.Dir(c.ConfigPath), os.ModeDir|os.FileMode(0755)); err != nil { + if err := os.MkdirAll(filepath.Dir(cfgPath), os.ModeDir|os.FileMode(0755)); err != nil { return err } - file, err := os.Create(c.ConfigPath) + file, err := os.Create(cfgPath) if err != nil { - return fmt.Errorf("unable to create %s reason: %w", c.ConfigPath, err) + return fmt.Errorf("unable to create %s reason: %w", cfgPath, err) } defer file.Close() diff --git a/go.mod b/go.mod index 401fd4dc..2f8cc266 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,6 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-jose/go-jose/v4 v4.0.1 // indirect - github.com/golang/protobuf v1.5.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect @@ -58,8 +57,6 @@ require ( golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.28.0 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/ini.v1 v1.42.0 // indirect gopkg.in/square/go-jose.v2 v2.5.1 // indirect diff --git a/go.sum b/go.sum index 58ef70e4..65318709 100644 --- a/go.sum +++ b/go.sum @@ -54,12 +54,7 @@ github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWq github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= @@ -151,11 +146,8 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= -golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -165,20 +157,12 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= From a735c0861cb7fabe3fc7b912faf08807389678a6 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Fri, 6 Dec 2024 06:23:24 -0800 Subject: [PATCH 29/30] Fallback to UserHomeDir() --- command/root.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/command/root.go b/command/root.go index f06c6de0..89a6cb75 100644 --- a/command/root.go +++ b/command/root.go @@ -44,7 +44,11 @@ To get started run the following commands: func getConfigPath() (string, error) { cfgDir, err := os.UserConfigDir() if err != nil { - return "", err + // UserConfigDir() and UserHomeDir() do slightly different things. If UserConfigDir() fails, try UserHomeDir() + cfgDir, err = os.UserHomeDir() + if err != nil { + return "", err + } } return filepath.Join(cfgDir, "keyconjurer", "config.json"), nil From 1753b207824e628fb8435537e5f6408e45459033 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 10 Dec 2024 12:30:58 -0800 Subject: [PATCH 30/30] Fix a bug where malformed ~/.aws/config would prevent loading credentials --- command/get.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/command/get.go b/command/get.go index 99f7373c..c46df796 100644 --- a/command/get.go +++ b/command/get.go @@ -7,7 +7,6 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/riotgames/key-conjurer/oauth2" "github.com/riotgames/key-conjurer/okta" @@ -129,12 +128,7 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, gl g.TimeToLive = cfg.TTL } - awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(g.Region)) - if err != nil { - return nil, err - } - - stsClient := sts.NewFromConfig(awsCfg) + stsClient := sts.New(sts.Options{Region: g.Region}) timeoutInSeconds := int32(3600 * g.TimeToLive) resp, err := stsClient.AssumeRoleWithSAML(ctx, &sts.AssumeRoleWithSAMLInput{ DurationSeconds: aws.Int32(timeoutInSeconds),