diff --git a/command/accounts.go b/command/accounts.go index 7227dbb8..6223600f 100644 --- a/command/accounts.go +++ b/command/accounts.go @@ -8,65 +8,60 @@ import ( "io" "net/http" "net/url" + "os" "github.com/riotgames/key-conjurer/internal/api" - "github.com/spf13/cobra" "golang.org/x/oauth2" ) -var ( - FlagNoRefresh = "no-refresh" - FlagServerAddress = "server-address" +var ErrSessionExpired = errors.New("session expired") - ErrSessionExpired = errors.New("session expired") -) +type AccountsCommand struct { + Refresh bool `help:"Refresh the list of accounts." default:"true" negatable:""` + ServerAddress string `help:"The address of the account server. This does not usually need to be changed or specified." hidden:"" env:"KEYCONJURER_SERVER_ADDRESS" default:"${server_address}"` +} -func init() { - accountsCmd.Flags().Bool(FlagNoRefresh, false, "Indicate that the account list should not be refreshed when executing this command. This is useful if you're not able to reach the account server.") - accountsCmd.Flags().String(FlagServerAddress, ServerAddress, "The address of the account server. This does not usually need to be changed or specified.") +func (a AccountsCommand) Help() string { + return "Prints and optionally refreshes the list of accounts you have access to." } -var accountsCmd = &cobra.Command{ - Use: "accounts", - Short: "Prints and optionally refreshes the list of accounts you have access to.", - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - stdOut := cmd.OutOrStdout() - noRefresh, _ := cmd.Flags().GetBool(FlagNoRefresh) - loud := !ShouldUseMachineOutput(cmd.Flags()) - if noRefresh { - config.DumpAccounts(stdOut, loud) - - if loud { - // intentionally uses PrintErrf was a warning - cmd.PrintErrf("--%s was specified - these results may be out of date, and you may not have access to accounts in this list.\n", FlagNoRefresh) - } - - return nil - } +func (a AccountsCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { + loud := isPiped() || globals.Quiet + if !a.Refresh { + config.DumpAccounts(os.Stdout, loud) - serverAddr, _ := cmd.Flags().GetString(FlagServerAddress) - serverAddrURI, err := url.Parse(serverAddr) - if err != nil { - return genericError{ - ExitCode: ExitCodeValueError, - Message: fmt.Sprintf("--%s had an invalid value: %s\n", FlagServerAddress, err), - } + if loud { + // intentionally uses Fprintf was a warning + fmt.Fprintf(os.Stderr, "--no-refresh was specified - these results may be out of date, and you may not have access to accounts in this list.\n") } - if HasTokenExpired(config.Tokens) { - return ErrTokensExpiredOrAbsent - } + return nil + } - accounts, err := refreshAccounts(cmd.Context(), serverAddrURI, config.Tokens) - if err != nil { - return fmt.Errorf("error refreshing accounts: %w", err) + serverAddrURI, err := url.Parse(a.ServerAddress) + if err != nil { + return genericError{ + ExitCode: ExitCodeValueError, + Message: fmt.Sprintf("server-address had an invalid value: %s\n", err), } + } - config.UpdateAccounts(accounts) - config.DumpAccounts(stdOut, loud) - return nil - }, + if HasTokenExpired(config.Tokens) { + return ErrTokensExpiredOrAbsent + } + + accounts, err := refreshAccounts(ctx, serverAddrURI, config.Tokens) + if err != nil { + return fmt.Errorf("error refreshing accounts: %w", err) + } + + config.UpdateAccounts(accounts) + config.DumpAccounts(os.Stdout, loud) + return nil +} + +func (a AccountsCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) } func refreshAccounts(ctx context.Context, serverAddr *url.URL, ts oauth2.TokenSource) ([]Account, error) { diff --git a/command/alias.go b/command/alias.go index ad8ced7c..c280396a 100644 --- a/command/alias.go +++ b/command/alias.go @@ -1,16 +1,30 @@ package command -import ( - "github.com/spf13/cobra" -) +import "context" -var aliasCmd = cobra.Command{ - Use: "alias ", - Short: "Give an account a nickname.", - Long: "Alias an account to a nickname so you can refer to the account by the nickname.", - Args: cobra.ExactArgs(2), - Example: "keyconjurer alias FooAccount Bar", - Run: func(cmd *cobra.Command, args []string) { - config := ConfigFromCommand(cmd) - config.Alias(args[0], args[1]) - }} +type AliasCommand struct { + AccountName string `arg:""` + Alias string `arg:""` +} + +func (a AliasCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) +} + +func (a AliasCommand) RunContext(ctx context.Context, _ *Globals, config *Config) error { + config.Alias(a.AccountName, a.Alias) + return nil +} + +type UnaliasCommand struct { + Alias string `arg:""` +} + +func (a UnaliasCommand) Run(globals *Globals, config *Config) error { + return a.RunContext(context.Background(), globals, config) +} + +func (a UnaliasCommand) RunContext(ctx context.Context, _ *Globals, config *Config) error { + config.Unalias(a.Alias) + return nil +} diff --git a/command/context.go b/command/context.go deleted file mode 100644 index 2329c49c..00000000 --- a/command/context.go +++ /dev/null @@ -1,26 +0,0 @@ -package command - -import ( - "context" - - "github.com/spf13/cobra" -) - -type configInfo struct { - Path string - Config *Config -} - -type ctxKeyConfig struct{} - -func ConfigFromCommand(cmd *cobra.Command) *Config { - return cmd.Context().Value(ctxKeyConfig{}).(*configInfo).Config -} - -func ConfigPathFromCommand(cmd *cobra.Command) string { - return cmd.Context().Value(ctxKeyConfig{}).(*configInfo).Path -} - -func ConfigContext(ctx context.Context, config *Config, path string) context.Context { - return context.WithValue(ctx, ctxKeyConfig{}, &configInfo{Path: path, Config: config}) -} diff --git a/command/get.go b/command/get.go index 751118cb..c46df796 100644 --- a/command/get.go +++ b/command/get.go @@ -4,49 +4,14 @@ import ( "context" "fmt" "os" - "slices" "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" + "github.com/riotgames/key-conjurer/okta" ) -var ( - FlagRegion = "region" - FlagRoleName = "role" - FlagTimeRemaining = "time-remaining" - FlagTimeToLive = "ttl" - FlagBypassCache = "bypass-cache" - FlagLogin = "login" -) - -var ( - // outputTypeEnvironmentVariable indicates that keyconjurer will dump the credentials to stdout in Bash environment variable format - outputTypeEnvironmentVariable = "env" - // outputTypeAWSCredentialsFile indicates that keyconjurer will dump the credentials into the ~/.aws/credentials file. - outputTypeAWSCredentialsFile = "awscli" - permittedOutputTypes = []string{outputTypeAWSCredentialsFile, outputTypeEnvironmentVariable} - permittedShellTypes = []string{shellTypePowershell, shellTypeBash, shellTypeBasic, shellTypeInfer} -) - -func init() { - getCmd.Flags().String(FlagRegion, "us-west-2", "The AWS region to use") - getCmd.Flags().Uint(FlagTimeToLive, 1, "The key timeout in hours from 1 to 8.") - getCmd.Flags().UintP(FlagTimeRemaining, "t", DefaultTimeRemaining, "Request new keys if there are no keys in the environment or the current keys expire within minutes. Defaults to 60.") - getCmd.Flags().StringP(FlagRoleName, "r", "", "The name of the role to assume.") - getCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs") - getCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli") - getCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`") - getCmd.Flags().Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.") - getCmd.Flags().Bool(FlagLogin, false, "Login to Okta before running the command") - getCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws CLI") - getCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message") - getCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") -} - func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Account, bool) { if bypassCache { return &Account{ID: nameOrID, Name: nameOrID}, true @@ -55,87 +20,59 @@ func resolveApplicationInfo(cfg *Config, bypassCache bool, nameOrID string) (*Ac } type GetCommand struct { - AccountIDOrName string - TimeToLive uint - TimeRemaining uint - OutputType, ShellType, RoleName, AWSCLIPath, OIDCDomain, ClientID, Region string - Login, URLOnly, NoBrowser, BypassCache, MachineOutput bool - - UsageFunc func() error - PrintErrln func(...any) -} - -func (g *GetCommand) Parse(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - g.OIDCDomain, _ = flags.GetString(FlagOIDCDomain) - g.ClientID, _ = flags.GetString(FlagClientID) - g.TimeToLive, _ = flags.GetUint(FlagTimeToLive) - g.TimeRemaining, _ = flags.GetUint(FlagTimeRemaining) - g.OutputType, _ = flags.GetString(FlagOutputType) - g.ShellType, _ = flags.GetString(FlagShellType) - g.RoleName, _ = flags.GetString(FlagRoleName) - g.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath) - g.Login, _ = flags.GetBool(FlagLogin) - g.URLOnly, _ = flags.GetBool(FlagURLOnly) - g.NoBrowser, _ = flags.GetBool(FlagNoBrowser) - g.BypassCache, _ = flags.GetBool(FlagBypassCache) - g.Region, _ = flags.GetString(FlagRegion) - g.UsageFunc = cmd.Usage - g.PrintErrln = cmd.PrintErrln - g.MachineOutput = ShouldUseMachineOutput(flags) || g.URLOnly - if len(args) == 0 { - return fmt.Errorf("account name or alias is required") - } - g.AccountIDOrName = args[0] - return nil + AccountNameOrID string `arg:""` + TimeToLive uint `placeholder:"hours" help:"The key timeout in hours from 1 to 8." default:"1" name:"ttl"` + TimeRemaining uint `placeholder:"minutes" help:"Request new keys if there are no keys in the environment or the current keys expire within minutes." default:"5" short:"t"` + AWSCLIPath string `help:"Path to the AWS CLI configuration directory." default:"~/.aws/" name:"awscli"` + ShellType string `name:"shell" help:"If output type is env, determines which format to output credentials in. WSL users may wish to overwrite this to \"bash\"." default:"infer" enum:"infer,basic,powershell,bash"` + URLOnly bool `help:"Print only the URL to visit rather than a user-friendly message." short:"u"` + Browser bool `help:"Open the browser to the Okta URL. If false, a URL will be printed to the command line instead." default:"true" negatable:"" short:"b"` + OutputType string `help:"Format to save new credentials in." default:"env" enum:"env,awscli" short:"o" default:"env" name:"out"` + Login bool `help:"Login to Okta before running the command if the tokens have expired."` + RoleName string `help:"The name of the role to assume." short:"r" name:"role"` + SessionName string `help:"The name of the role session name that will show up in CloudTrail logs." default:"KeyConjurer-AssumeRole"` + Region string `help:"The AWS region to use." env:"AWS_REGION" default:"us-west-2"` + Cache bool `help:"Check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache." default:"true" negatable:"" hidden:""` + + UsageFunc func() error `kong:"-"` + PrintErrln func(...any) `kong:"-"` } -func (g GetCommand) Validate() error { - if !slices.Contains(permittedOutputTypes, g.OutputType) { - return ValueError{Value: g.OutputType, ValidValues: permittedOutputTypes} - } +func (g GetCommand) Help() string { + return `Retrieves temporary cloud API credentials for the specified account. - if !slices.Contains(permittedShellTypes, g.ShellType) { - return ValueError{Value: g.ShellType, ValidValues: permittedShellTypes} - } - return nil +A role must be specified when using this command through the --role flag. You may list the roles you can assume through the roles command, and the accounts through the accounts command.` } func (g GetCommand) printUsage() error { return g.UsageFunc() } -func (g GetCommand) Execute(ctx context.Context, config *Config) error { - if HasTokenExpired(config.Tokens) { +func (g GetCommand) RunContext(ctx context.Context, globals *Globals, cfg *Config) error { + if HasTokenExpired(cfg.Tokens) { if !g.Login { return ErrTokensExpiredOrAbsent } - loginCommand := LoginCommand{ - OIDCDomain: g.OIDCDomain, - ClientID: g.ClientID, - MachineOutput: g.MachineOutput, - NoBrowser: g.NoBrowser, - } - - if err := loginCommand.Execute(ctx, config); err != nil { + var loginCommand LoginCommand + if err := loginCommand.RunContext(ctx, globals, cfg); err != nil { return err } } var accountID string - if g.AccountIDOrName != "" { - accountID = g.AccountIDOrName - } else if config.LastUsedAccount != nil { + if g.AccountNameOrID != "" { + accountID = g.AccountNameOrID + } else if cfg.LastUsedAccount != nil { // No account specified. Can we use the most recent one? - accountID = *config.LastUsedAccount + accountID = *cfg.LastUsedAccount } else { return g.printUsage() } - account, ok := resolveApplicationInfo(config, g.BypassCache, accountID) + account, ok := resolveApplicationInfo(cfg, !g.Cache, accountID) if !ok { - return UnknownAccountError(g.AccountIDOrName, FlagBypassCache) + return UnknownAccountError(g.AccountNameOrID, "--no-cache") } if g.RoleName == "" { @@ -146,13 +83,13 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error { g.RoleName = account.MostRecentRole } - if config.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { - g.TimeRemaining = config.TimeRemaining + if cfg.TimeRemaining != 0 && g.TimeRemaining == DefaultTimeRemaining { + g.TimeRemaining = cfg.TimeRemaining } credentials := LoadAWSCredentialsFromEnvironment() if !credentials.ValidUntil(account, time.Duration(g.TimeRemaining)*time.Minute) { - newCredentials, err := g.fetchNewCredentials(ctx, *account, config) + newCredentials, err := g.fetchNewCredentials(ctx, *account, globals, cfg) if err != nil { return err } @@ -163,31 +100,35 @@ func (g GetCommand) Execute(ctx context.Context, config *Config) error { account.MostRecentRole = g.RoleName } - config.LastUsedAccount = &accountID + cfg.LastUsedAccount = &accountID return echoCredentials(accountID, accountID, credentials, g.OutputType, g.ShellType, g.AWSCLIPath) } -func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cfg *Config) (*CloudCredentials, error) { - samlResponse, assertionStr, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(ctx, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, g.OIDCDomain, g.ClientID, account.ID) +func (g GetCommand) Run(globals *Globals, cfg *Config) error { + return g.RunContext(context.Background(), globals, cfg) +} + +func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, globals *Globals, cfg *Config) (*CloudCredentials, error) { + oauth2Cfg, err := oauth2.DiscoverConfig(ctx, globals.OIDCDomain, globals.ClientID) + if err != nil { + return nil, err + } + + samlResponse, assertionStr, err := okta.ExchangeTokenForAssertion(ctx, oauth2Cfg, cfg.Tokens.AccessToken, cfg.Tokens.IDToken, globals.OIDCDomain, account.ID) if err != nil { return nil, err } pair, ok := findRoleInSAML(g.RoleName, samlResponse) if !ok { - return nil, UnknownRoleError(g.RoleName, g.AccountIDOrName) + return nil, UnknownRoleError(g.RoleName, g.AccountNameOrID) } if g.TimeToLive == 1 && cfg.TTL != 0 { g.TimeToLive = cfg.TTL } - awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(g.Region)) - if err != nil { - return nil, err - } - - stsClient := sts.NewFromConfig(awsCfg) + stsClient := sts.New(sts.Options{Region: g.Region}) timeoutInSeconds := int32(3600 * g.TimeToLive) resp, err := stsClient.AssumeRoleWithSAML(ctx, &sts.AssumeRoleWithSAMLInput{ DurationSeconds: aws.Int32(timeoutInSeconds), @@ -215,32 +156,12 @@ func (g GetCommand) fetchNewCredentials(ctx context.Context, account Account, cf }, nil } -var getCmd = &cobra.Command{ - Use: "get ", - Short: "Retrieves temporary cloud API credentials.", - Long: `Retrieves temporary cloud API credentials for the specified account. It sends a push request to the first Duo device it finds associated with your account. - -A role must be specified when using this command through the --role flag. You may list the roles you can assume through the roles command.`, - RunE: func(cmd *cobra.Command, args []string) error { - var getCmd GetCommand - if err := getCmd.Parse(cmd, args); err != nil { - return err - } - - if err := getCmd.Validate(); err != nil { - return err - } - - return getCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) - }, -} - func echoCredentials(id, name string, credentials CloudCredentials, outputType, shellType, cliPath string) error { switch outputType { - case outputTypeEnvironmentVariable: + case "env": credentials.WriteFormat(os.Stdout, shellType) return nil - case outputTypeAWSCredentialsFile: + case "aws": acc := Account{ID: id, Name: name} newCliEntry := NewCloudCliEntry(credentials, &acc) return SaveCloudCredentialInCLI(cliPath, newCliEntry) diff --git a/command/login.go b/command/login.go index b6c69233..8e321b8c 100644 --- a/command/login.go +++ b/command/login.go @@ -5,72 +5,32 @@ import ( "errors" "fmt" "net" + "net/http" "os" "log/slog" + "github.com/coreos/go-oidc" "github.com/pkg/browser" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" - "github.com/spf13/pflag" ) -var ( - FlagURLOnly = "url-only" - FlagNoBrowser = "no-browser" -) - -func init() { - loginCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message") - loginCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") -} - -var loginCmd = &cobra.Command{ - Use: "login", - Short: "Authenticate with KeyConjurer.", - Long: "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code.", - RunE: func(cmd *cobra.Command, args []string) error { - var loginCmd LoginCommand - if err := loginCmd.Parse(cmd.Flags(), args); err != nil { - return err - } - - return loginCmd.Execute(cmd.Context(), ConfigFromCommand(cmd)) - }, -} - -// ShouldUseMachineOutput indicates whether or not we should write to standard output as if the user is a machine. -// -// What this means is implementation specific, but this usually indicates the user is trying to use this program in a script and we should avoid user-friendly output messages associated with values a user might find useful. -func ShouldUseMachineOutput(flags *pflag.FlagSet) bool { - quiet, _ := flags.GetBool(FlagQuiet) - fi, _ := os.Stdout.Stat() - isPiped := fi.Mode()&os.ModeCharDevice == 0 - return isPiped || quiet -} - type LoginCommand struct { - OIDCDomain string - ClientID string - MachineOutput bool - NoBrowser bool + URLOnly bool `help:"Print only the URL to visit rather than a user-friendly message." short:"u"` + Browser bool `help:"Open the browser to the Okta URL. If false, a URL will be printed to the command line instead." default:"true" negatable:"" short:"b"` } -func (c *LoginCommand) Parse(flags *pflag.FlagSet, args []string) error { - c.OIDCDomain, _ = flags.GetString(FlagOIDCDomain) - c.ClientID, _ = flags.GetString(FlagClientID) - c.NoBrowser, _ = flags.GetBool(FlagNoBrowser) - urlOnly, _ := flags.GetBool(FlagURLOnly) - c.MachineOutput = ShouldUseMachineOutput(flags) || urlOnly - return nil +func (c LoginCommand) Help() string { + return "Login to KeyConjurer using OAuth2. You will be required to open the URL printed to the console or scan a QR code." } -func (c LoginCommand) Execute(ctx context.Context, config *Config) error { +func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { if !HasTokenExpired(config.Tokens) { return nil } - oauthCfg, err := oauth2.DiscoverConfig(ctx, c.OIDCDomain, c.ClientID) + client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} + oauthCfg, err := oauth2.DiscoverConfig(oidc.ClientContext(ctx, client), globals.OIDCDomain, globals.ClientID) if err != nil { return err } @@ -87,31 +47,43 @@ func (c LoginCommand) Execute(ctx context.Context, config *Config) error { } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := oauth2.RedirectionFlowHandler{ - Config: oauthCfg, - OnDisplayURL: openBrowserToURL, - } - - if c.NoBrowser { - if c.MachineOutput { - handler.OnDisplayURL = printURLToConsole + handler := oauth2.NewAuthorizationCodeHandler(oauthCfg) + session := handler.NewSession() + if !c.Browser { + if isPiped() || globals.Quiet { + printURLToConsole(session.URL()) } else { - handler.OnDisplayURL = friendlyPrintURLToConsole + friendlyPrintURLToConsole(session.URL()) } + } else { + browser.OpenURL(session.URL()) } - accessToken, err := handler.HandlePendingSession(ctx, sock, oauth2.GeneratePkceChallenge(), oauth2.GenerateState()) - if err != nil { - return err - } + errCh := make(chan error, 1) + go func() { + err := http.Serve(sock, handler) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() - // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - idToken, ok := accessToken.Extra("id_token").(string) - if !ok { - return fmt.Errorf("id_token not found in token response") + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case err := <-session.Error: + return err + case token := <-session.Token: + // TODO Will panic if id_token not present + // TODO Verify token with OIDC provider + idToken := token.Extra("id_token").(string) + return config.SaveOAuthToken(token, idToken) } +} - return config.SaveOAuthToken(accessToken, idToken) +func (c LoginCommand) Run(globals *Globals, config *Config) error { + return c.RunContext(context.Background(), globals, config) } var ErrNoPortsAvailable = errors.New("no ports available") @@ -151,3 +123,8 @@ func openBrowserToURL(url string) error { slog.Debug("trying to open browser window", slog.String("url", url)) return browser.OpenURL(url) } + +func isPiped() bool { + fi, _ := os.Stdout.Stat() + return fi.Mode()&os.ModeCharDevice == 0 +} diff --git a/command/roles.go b/command/roles.go index 57a16962..237df219 100644 --- a/command/roles.go +++ b/command/roles.go @@ -1,43 +1,56 @@ package command import ( + "context" + "fmt" "strings" "github.com/RobotsAndPencils/go-saml" "github.com/riotgames/key-conjurer/oauth2" - "github.com/spf13/cobra" + "github.com/riotgames/key-conjurer/okta" ) -var rolesCmd = cobra.Command{ - Use: "roles ", - Short: "Returns the roles that you have access to in the given account.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - if HasTokenExpired(config.Tokens) { - return ErrTokensExpiredOrAbsent - } +type RolesCommand struct { + ApplicationID string `arg:""` +} - oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) - clientID, _ := cmd.Flags().GetString(FlagClientID) +func (r RolesCommand) Run(globals *Globals, config *Config) error { + return r.RunContext(context.Background(), globals, config) +} - var applicationID = args[0] - account, ok := config.FindAccount(applicationID) - if ok { - applicationID = account.ID - } +func (r RolesCommand) RunContext(ctx context.Context, globals *Globals, config *Config) error { + if HasTokenExpired(config.Tokens) { + return ErrTokensExpiredOrAbsent + } - samlResponse, _, err := oauth2.DiscoverConfigAndExchangeTokenForAssertion(cmd.Context(), config.Tokens.AccessToken, config.Tokens.IDToken, oidcDomain, clientID, applicationID) - if err != nil { - return err - } + account, ok := config.FindAccount(r.ApplicationID) + if ok { + r.ApplicationID = account.ID + } - for _, name := range listRoles(samlResponse) { - cmd.Println(name) - } + cfg, err := oauth2.DiscoverConfig(ctx, globals.OIDCDomain, globals.ClientID) + if err != nil { + return err + } - return nil - }, + samlResponse, _, err := okta.ExchangeTokenForAssertion( + ctx, + cfg, + config.Tokens.AccessToken, + config.Tokens.IDToken, + globals.OIDCDomain, + r.ApplicationID, + ) + + if err != nil { + return err + } + + for _, name := range listRoles(samlResponse) { + fmt.Println(name) + } + + return nil } type roleProviderPair struct { diff --git a/command/root.go b/command/root.go index b39fe717..89a6cb75 100644 --- a/command/root.go +++ b/command/root.go @@ -1,115 +1,128 @@ package command import ( - "context" "fmt" - "net/http" "os" "path/filepath" "runtime" - "time" - "github.com/coreos/go-oidc" - "github.com/mitchellh/go-homedir" - "github.com/spf13/cobra" + "github.com/alecthomas/kong" ) -var ( - FlagOIDCDomain = "oidc-domain" - FlagClientID = "client-id" - FlagConfigPath = "config" - FlagQuiet = "quiet" - FlagTimeout = "timeout" -) +type Globals struct { + OIDCDomain string `help:"The domain name of your OIDC server." hidden:"" env:"KEYCONJURER_OIDC_DOMAIN" default:"${oidc_domain}"` + ClientID string `help:"The client ID of your OIDC server." hidden:"" env:"KEYCONJURER_CLIENT_ID" default:"${client_id}"` + Quiet bool `help:"Tells the CLI to be quiet; stdout will not contain human-readable informational messages."` +} + +type CLI struct { + Globals -func init() { - rootCmd.PersistentFlags().String(FlagOIDCDomain, OIDCDomain, "The domain name of your OIDC server") - rootCmd.PersistentFlags().String(FlagClientID, ClientID, "The OAuth2 Client ID for the application registered with your OIDC server") - rootCmd.PersistentFlags().Int(FlagTimeout, 120, "the amount of time in seconds to wait for keyconjurer to respond") - rootCmd.PersistentFlags().String(FlagConfigPath, "~/.keyconjurerrc", "path to .keyconjurerrc file") - rootCmd.PersistentFlags().Bool(FlagQuiet, false, "tells the CLI to be quiet; stdout will not contain human-readable informational messages") - rootCmd.AddCommand(loginCmd) - rootCmd.AddCommand(accountsCmd) - rootCmd.AddCommand(getCmd) - rootCmd.AddCommand(setCmd) - rootCmd.AddCommand(&switchCmd) - rootCmd.AddCommand(&aliasCmd) - rootCmd.AddCommand(&unaliasCmd) - rootCmd.AddCommand(&rolesCmd) - rootCmd.SetVersionTemplate("{{.Version}}\n") - - rootCmd.PersistentFlags().MarkHidden(FlagOIDCDomain) - rootCmd.PersistentFlags().MarkHidden(FlagClientID) + Login LoginCommand `cmd:"" help:"Authenticate with KeyConjurer."` + Accounts AccountsCommand `cmd:"" help:"Display accounts."` + Get GetCommand `cmd:"" help:"Retrieve temporary cloud credentials."` + Switch SwitchCommand `cmd:"" help:"Switch between accounts."` + Roles RolesCommand `cmd:"" help:"Display roles for a specific account."` + Alias AliasCommand `cmd:"" help:"Create an alias for an account."` + Unalias UnaliasCommand `cmd:"" help:"Remove an alias."` + Set SetCommand `cmd:"" help:"Set config values."` + + Version kong.VersionFlag `help:"Show version information." short:"v" flag:""` + + Config Config `kong:"-"` } -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "keyconjurer", - Version: fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), - Short: "Retrieve temporary cloud credentials.", - Long: `KeyConjurer retrieves temporary credentials from Okta with the assistance of an optional API. +func (CLI) Help() string { + return `KeyConjurer retrieves temporary credentials from Okta with the assistance of an optional API. To get started run the following commands: keyconjurer login keyconjurer accounts - keyconjurer get -`, - FParseErrWhitelist: cobra.FParseErrWhitelist{ - UnknownFlags: true, - }, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - var config Config - // The error of this function call is only non-nil if the flag was not provided or is not a string. - configPath, _ := cmd.Flags().GetString(FlagConfigPath) - if expanded, err := homedir.Expand(configPath); err == nil { - configPath = expanded - } + keyconjurer get ` +} - file, err := EnsureConfigFileExists(configPath) +func getConfigPath() (string, error) { + cfgDir, err := os.UserConfigDir() + if err != nil { + // UserConfigDir() and UserHomeDir() do slightly different things. If UserConfigDir() fails, try UserHomeDir() + cfgDir, err = os.UserHomeDir() if err != nil { - return err + return "", err } + } - if err := config.Read(file); err != nil { - return err - } + return filepath.Join(cfgDir, "keyconjurer", "config.json"), nil +} - // We don't care about this being cancelled. - timeout, _ := cmd.Flags().GetInt(FlagTimeout) - nextCtx, _ := context.WithTimeout(cmd.Context(), time.Duration(timeout)*time.Second) - cmd.SetContext(ConfigContext(nextCtx, &config, configPath)) - return nil - }, - PersistentPostRunE: func(cmd *cobra.Command, _ []string) error { - config := ConfigFromCommand(cmd) - path := ConfigPathFromCommand(cmd) - if expanded, err := homedir.Expand(path); err == nil { - path = expanded - } +func (c *CLI) BeforeApply(ctx *kong.Context, trace *kong.Path) error { + cfgPath, err := getConfigPath() + if err != nil { + return err + } - // Do not use EnsureConfigFileExists here! - // EnsureConfigFileExists opens the file in append mode. - // If we open the file in append mode, we'll always append to the file. If we open the file in truncate mode before reading from the file, the content will be truncated _before we read from it_, which will cause a users configuration to be discarded every time we run the program. + file, err := EnsureConfigFileExists(cfgPath) + if err != nil { + return err + } - if err := os.MkdirAll(filepath.Dir(path), os.ModeDir|os.FileMode(0755)); err != nil { - return err - } + err = c.Config.Read(file) + if err != nil { + return err + } - file, err := os.Create(path) - if err != nil { - return fmt.Errorf("unable to create %s reason: %w", path, err) - } + // Make *Config available to all sub-commands. + // This must be &c.Config because c.Config is not a pointer. + ctx.Bind(&c.Config) + return nil +} + +func (c *CLI) AfterRun(ctx *kong.Context) error { + cfgPath, err := getConfigPath() + if err != nil { + return err + } + + // Do not use EnsureConfigFileExists here! EnsureConfigFileExists opens the file in append mode. + // If we open the file in append mode, we'll always append to the file. If we open the file in truncate mode before reading from the file, the content will be truncated _before we read from it_, which will cause a users configuration to be discarded every time we run the program. + if err := os.MkdirAll(filepath.Dir(cfgPath), os.ModeDir|os.FileMode(0755)); err != nil { + return err + } - defer file.Close() - return config.Write(file) - }, - SilenceErrors: true, - SilenceUsage: true, + file, err := os.Create(cfgPath) + if err != nil { + return fmt.Errorf("unable to create %s reason: %w", cfgPath, err) + } + + defer file.Close() + return c.Config.Write(file) +} + +func newKong(cli *CLI) (*kong.Kong, error) { + return kong.New( + cli, + kong.Name("keyconjurer"), + kong.Description("Retrieve temporary cloud credentials."), + kong.UsageOnError(), + kong.Vars{ + "client_id": ClientID, + "server_address": ServerAddress, + "oidc_domain": OIDCDomain, + "version": fmt.Sprintf("keyconjurer-%s-%s %s (%s)", runtime.GOOS, runtime.GOARCH, Version, BuildTimestamp), + }, + ) } -func Execute(ctx context.Context, args []string) error { - client := &http.Client{Transport: LogRoundTripper{http.DefaultTransport}} - ctx = oidc.ClientContext(ctx, client) - rootCmd.SetArgs(args) - return rootCmd.ExecuteContext(ctx) +func Execute(args []string) error { + var cli CLI + k, err := newKong(&cli) + if err != nil { + return err + } + + kongCtx, err := k.Parse(args) + if err != nil { + return err + } + + return kongCtx.Run(&cli.Globals) } diff --git a/command/root_test.go b/command/root_test.go deleted file mode 100644 index 3864e3d7..00000000 --- a/command/root_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package command - -import ( - "bytes" - "fmt" - "runtime" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/assert" -) - -// Required to reset Cobra state between Test runs. -// If you add new tests that change state, you may need to -// add code here to reset the side effects of other tests -func resetCobra(cmd *cobra.Command) { - cmd.Flags().Set("help", "false") - cmd.Flags().Set("version", "false") -} - -func execute(cmd *cobra.Command, args ...string) (string, error) { - var buf bytes.Buffer - cmd.SetArgs(args) - cmd.SetOutput(&buf) - err := cmd.Execute() - return buf.String(), err -} - -func TestVersionFlag(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "--version") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - expected := fmt.Sprintf("keyconjurer-%s-%s TBD (BuildTimestamp is not set)\n", runtime.GOOS, runtime.GOARCH) - assert.Equal(t, output, expected) -} - -func TestVersionShortFlag(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "-v") - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - expected := fmt.Sprintf("keyconjurer-%s-%s TBD (BuildTimestamp is not set)\n", runtime.GOOS, runtime.GOARCH) - assert.Equal(t, output, expected) -} - -func TestHelpLongInvalidArgs(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - output, err := execute(rootCmd, "get", "-s") - if err != nil { - if err.Error() != "unknown shorthand flag: 's' in -s" { - t.Errorf("Unexpected error: %v", err) - } - } else { - t.Errorf("Unexpected non-error, output=: %v", output) - } -} - -func TestInvalidCommand(t *testing.T) { - t.Cleanup(func() { - resetCobra(rootCmd) - }) - - output, err := execute(rootCmd, "badcommand") - if err != nil { - if err.Error() != "unknown command \"badcommand\" for \"keyconjurer\"" { - t.Errorf("Unexpected error: %v", err) - } - } else { - t.Errorf("Unexpected non-error, output=: %v", output) - } -} diff --git a/command/set.go b/command/set.go index d071b0c4..f877b072 100644 --- a/command/set.go +++ b/command/set.go @@ -3,51 +3,37 @@ package command import ( "fmt" "strconv" - - "github.com/spf13/cobra" ) -func init() { - setCmd.AddCommand(setTTLCmd) - setCmd.AddCommand(setTimeRemainingCmd) +type SetCommand struct { + TTL TTLCommand `cmd:"" name:"ttl" help:"Sets ttl value in number of hours."` + TimeRemaining TimeRemainingCommand `cmd:"" name:"time-remaining" help:"Sets time remaining value in number of minutes."` } -var setCmd = &cobra.Command{ - Use: "set", - Short: "Sets config values.", - Long: "Sets config values.", +type TTLCommand struct { + TTL string `arg:"" help:"The ttl value in number of hours." placeholder:"hours"` } -var setTTLCmd = &cobra.Command{ - Use: "ttl ", - Short: "Sets ttl value in number of hours.", - Long: "Sets ttl value in number of hours.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - ttl, err := strconv.ParseUint(args[0], 10, 32) - if err != nil { - return fmt.Errorf("unable to parse value %s", args[0]) - } - - config.TTL = uint(ttl) - return nil - }, +func (t TTLCommand) Run(globals *Globals, config *Config) error { + ttl, err := strconv.ParseUint(t.TTL, 10, 32) + if err != nil { + return fmt.Errorf("unable to parse value %s", t.TTL) + } + + config.TTL = uint(ttl) + return nil } -var setTimeRemainingCmd = &cobra.Command{ - Use: "time-remaining ", - Short: "Sets time remaining value in number of minutes.", - Long: "Sets time remaining value in number of minutes. Using minutes is an artifact from when keys could only live for 1 hour.", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - config := ConfigFromCommand(cmd) - timeRemaining, err := strconv.ParseUint(args[0], 10, 32) - if err != nil { - return fmt.Errorf("unable to parse value %s", args[0]) - } - - config.TimeRemaining = uint(timeRemaining) - return nil - }, +type TimeRemainingCommand struct { + TimeRemaining string `arg:"" help:"The time remaining value in number of minutes." placeholder:"minutes"` +} + +func (t TimeRemainingCommand) Run(globals *Globals, config *Config) error { + timeRemaining, err := strconv.ParseUint(t.TimeRemaining, 10, 32) + if err != nil { + return fmt.Errorf("unable to parse value %s", t.TimeRemaining) + } + + config.TimeRemaining = uint(timeRemaining) + return nil } diff --git a/command/switch.go b/command/switch.go index 35c0e08b..4ff0904c 100644 --- a/command/switch.go +++ b/command/switch.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "slices" "strings" "time" @@ -12,86 +11,27 @@ import ( "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/spf13/cobra" - "github.com/spf13/pflag" ) -var ( - FlagRoleSessionName = "role-session-name" - FlagOutputType = "out" - FlagShellType = "shell" - FlagAWSCLIPath = "awscli" -) - -func init() { - switchCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs") - switchCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli") - switchCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`") - switchCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws-cli tool. Default is \"~/.aws\".") -} - -var switchCmd = cobra.Command{ - Use: "switch ", - Short: "Switch from the current AWS account into the one with the given Account ID.", - Long: `Attempt to AssumeRole into the given AWS with the current credentials. You only need to use this if you are a power user or network engineer with access to many accounts. - -This is used when a "bastion" account exists which users initially authenticate into and then pivot from that account into other accounts. - -This command will fail if you do not have active Cloud credentials. -`, - Example: "keyconjurer switch 123456798", - Args: cobra.ExactArgs(1), - Aliases: []string{"switch-account"}, - RunE: func(cmd *cobra.Command, args []string) error { - var switchCmd SwitchCommand - if err := switchCmd.Parse(cmd.Flags(), args); err != nil { - return err - } - - if err := switchCmd.Validate(); err != nil { - return err - } - - return switchCmd.Execute(cmd.Context()) - }, -} - type SwitchCommand struct { - OutputType string - ShellType string - AWSCLIPath string - RoleSessionName string - AccountID string + OutputType string `help:"Format to save new credentials in." default:"env" enum:"env,awscli" short:"o" default:"env" name:"out"` + AWSCLIPath string `help:"Path to the AWS CLI configuration directory." default:"~/.aws/" name:"awscli"` + ShellType string `name:"shell" help:"If output type is env, determines which format to output credentials in. WSL users may wish to overwrite this to \"bash\"." default:"infer" enum:"infer,basic,powershell,bash"` + SessionName string `help:"The name of the role session name that will show up in CloudTrail logs." default:"KeyConjurer-AssumeRole"` + AccountID string `arg:"" placeholder:"account-id"` } -func (s *SwitchCommand) Parse(flags *pflag.FlagSet, args []string) error { - s.OutputType, _ = flags.GetString(FlagOutputType) - s.ShellType, _ = flags.GetString(FlagShellType) - s.AWSCLIPath, _ = flags.GetString(FlagAWSCLIPath) - s.RoleSessionName, _ = flags.GetString(FlagRoleSessionName) - if len(args) == 0 { - return fmt.Errorf("account-id is required") - } - - s.AccountID = args[0] - return nil -} - -func (s SwitchCommand) Validate() error { - if !slices.Contains(permittedOutputTypes, s.OutputType) { - return ValueError{Value: s.OutputType, ValidValues: permittedOutputTypes} - } +func (SwitchCommand) Help() string { + return `Attempt to AssumeRole into the given AWS with the current credentials. You only need to use this if you are a power user or network engineer with access to many accounts. - if !slices.Contains(permittedShellTypes, s.ShellType) { - return ValueError{Value: s.ShellType, ValidValues: permittedShellTypes} - } +This is used when a "bastion" account exists which users initially authenticate into and then pivot from that account into other accounts. - return nil +This command will fail if you do not have active Cloud credentials.` } -func (s SwitchCommand) Execute(ctx context.Context) error { +func (s SwitchCommand) RunContext(ctx context.Context) error { // We could read the environment variable for the assumed role ARN, but it might be expired which isn't very useful to the user. - creds, err := getAWSCredentials(ctx, s.AccountID, s.RoleSessionName) + creds, err := getAWSCredentials(ctx, s.AccountID, s.SessionName) if err != nil { // If this failed, either there was a network error or the user is not authorized to assume into this role // This can happen if the user is not authenticated using the Bastion instance. @@ -99,10 +39,10 @@ func (s SwitchCommand) Execute(ctx context.Context) error { } switch s.OutputType { - case outputTypeEnvironmentVariable: + case "env": creds.WriteFormat(os.Stdout, s.ShellType) return nil - case outputTypeAWSCredentialsFile: + case "aws": acc := Account{ID: s.AccountID, Name: s.AccountID} newCliEntry := NewCloudCliEntry(creds, &acc) return SaveCloudCredentialInCLI(s.AWSCLIPath, newCliEntry) @@ -111,6 +51,10 @@ func (s SwitchCommand) Execute(ctx context.Context) error { } } +func (s SwitchCommand) Run() error { + return s.RunContext(context.Background()) +} + func getAWSCredentials(ctx context.Context, accountID, roleSessionName string) (creds CloudCredentials, err error) { cfg, err := config.LoadDefaultConfig(ctx) if err != nil { diff --git a/command/unalias.go b/command/unalias.go deleted file mode 100644 index aef1ea05..00000000 --- a/command/unalias.go +++ /dev/null @@ -1,15 +0,0 @@ -package command - -import ( - "github.com/spf13/cobra" -) - -var unaliasCmd = cobra.Command{ - Use: "unalias ", - Short: "Remove alias from account.", - Args: cobra.ExactArgs(1), - Example: "keyconjurer unalias bar", - Run: func(cmd *cobra.Command, args []string) { - config := ConfigFromCommand(cmd) - config.Unalias(args[0]) - }} diff --git a/go.mod b/go.mod index a3748bb2..2f8cc266 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,7 @@ module github.com/riotgames/key-conjurer require ( github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b + github.com/alecthomas/kong v1.4.0 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.32.4 github.com/aws/aws-sdk-go-v2/config v1.28.3 @@ -14,11 +15,9 @@ require ( github.com/mitchellh/go-ps v1.0.0 github.com/okta/okta-sdk-golang/v2 v2.2.1 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 - github.com/spf13/cobra v1.7.0 - github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 golang.org/x/net v0.25.0 - golang.org/x/oauth2 v0.6.0 + golang.org/x/oauth2 v0.24.0 ) require ( @@ -34,7 +33,6 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-jose/go-jose/v4 v4.0.1 // indirect - github.com/golang/protobuf v1.5.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect @@ -44,7 +42,6 @@ require ( github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/kr/text v0.2.0 // indirect @@ -60,8 +57,6 @@ require ( golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/protobuf v1.28.0 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/ini.v1 v1.42.0 // indirect gopkg.in/square/go-jose.v2 v2.5.1 // indirect diff --git a/go.sum b/go.sum index 90f06203..65318709 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,11 @@ github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b h1:EgJ6N2S0h1WfFIjU5/VVHWbMSVYXAluop97Qxpr/lfQ= github.com/RobotsAndPencils/go-saml v0.0.0-20170520135329-fb13cb52a46b/go.mod h1:3SAoF0F5EbcOuBD5WT9nYkbIJieBS84cUQXADbXeBsU= +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/kong v1.4.0 h1:UL7tzGMnnY0YRMMvJyITIRX1EpO6RbBRZDNcCevy3HA= +github.com/alecthomas/kong v1.4.0/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-lambda-go v1.47.0 h1:0H8s0vumYx/YKs4sE7YM0ktwL2eWse+kfopsRI1sXVI= github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= @@ -35,7 +41,6 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk= github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -49,12 +54,7 @@ github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWq github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= @@ -84,8 +84,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/vault/api v1.15.0 h1:O24FYQCWwhwKnF7CuSqP30S51rTV7vz1iACXE/pj5DA= github.com/hashicorp/vault/api v1.15.0/go.mod h1:+5YTO09JGn0u+b6ySD/LLVf8WkJCPLAL2Vkmrn2+CM8= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/jarcoal/httpmock v1.0.6/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= @@ -128,7 +128,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= @@ -136,10 +135,6 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a h1:pa8hGb/2YqsZKovtsgrwcDH1RZhVbTKCjLp47XpqCDs= github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -151,11 +146,10 @@ golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= -golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -163,20 +157,12 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= -google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= diff --git a/internal/api/settings.go b/internal/api/settings.go index c929f8a3..6c7971fd 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -14,14 +14,12 @@ type Settings struct { OktaToken string `json:"oktaToken"` } -var SettingsProviders = map[string]SettingsProvider{} - -func init() { - SettingsProviders["env"] = SettingsProviderFunc(RetrieveSettingsFromEnv) - SettingsProviders["vault"] = VaultRetriever{ +var SettingsProviders = map[string]SettingsProvider{ + "env": SettingsProviderFunc(RetrieveSettingsFromEnv), + "vault": VaultRetriever{ SecretMountPath: os.Getenv("KC_SECRET_MOUNT_PATH"), SecretPath: os.Getenv("KC_SECRET_PATH"), - } + }, } type SettingsProvider interface { diff --git a/main.go b/main.go index 2dfa7880..34fd985d 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "errors" "fmt" "os" @@ -11,23 +10,9 @@ import ( "log/slog" "github.com/riotgames/key-conjurer/command" - "github.com/spf13/cobra" ) -const ( - // WSAEACCES is the Windows error code for attempting to access a socket that you don't have permission to access. - // - // This commonly occurs if the socket is in use or was not closed correctly, and can be resolved by restarting the hns service. - WSAEACCES = 10013 -) - -// IsWindowsPortAccessError determines if the given error is the error WSAEACCES. -func IsWindowsPortAccessError(err error) bool { - var syscallErr *syscall.Errno - return errors.As(err, &syscallErr) && *syscallErr == WSAEACCES -} - -func init() { +func main() { var opts slog.HandlerOptions if os.Getenv("DEBUG") == "1" { opts.Level = slog.LevelDebug @@ -35,24 +20,21 @@ func init() { handler := slog.NewTextHandler(os.Stdout, &opts) slog.SetDefault(slog.New(handler)) -} -func main() { args := os.Args[1:] if flag, ok := os.LookupEnv("KEYCONJURERFLAGS"); ok { args = append(args, strings.Split(flag, " ")...) } - err := command.Execute(context.Background(), args) - if IsWindowsPortAccessError(err) { + err := command.Execute(args) + if isWindowsPortAccessError(err) { fmt.Fprintf(os.Stderr, "Encountered an issue when opening the port for KeyConjurer: %s\n", err) fmt.Fprintln(os.Stderr, "Consider running `net stop hns` and then `net start hns`") os.Exit(command.ExitCodeConnectivityError) } if err != nil { - cobra.CheckErr(err) - + fmt.Fprintln(os.Stderr, "Error:", err) errorCode, ok := command.GetExitCode(err) if !ok { errorCode = command.ExitCodeUnknownError @@ -60,3 +42,16 @@ func main() { os.Exit(errorCode) } } + +const ( + // wsaeacces is the Windows error code for attempting to access a socket that you don't have permission to access. + // + // This commonly occurs if the socket is in use or was not closed correctly, and can be resolved by restarting the hns service. + wsaeacces = 10013 +) + +// isWindowsPortAccessError determines if the given error is the error wsaeacces. +func isWindowsPortAccessError(err error) bool { + var syscallErr *syscall.Errno + return errors.As(err, &syscallErr) && *syscallErr == wsaeacces +} diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 6f232275..3c951ef6 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -3,22 +3,15 @@ package oauth2 import ( "context" "crypto/rand" - "crypto/sha256" "encoding/base64" - "errors" "fmt" - "net" "net/http" - "strings" "sync" - "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" "golang.org/x/oauth2" ) -var ErrNoSAMLAssertion = errors.New("no saml assertion") - // stateBufSize is the size of the buffer used to generate the state parameter. // 43 is a magic number - It generates states that are not too short or long for Okta's validation. const stateBufSize = 43 @@ -32,180 +25,111 @@ func DiscoverConfig(ctx context.Context, domain, clientID string) (*oauth2.Confi cfg := oauth2.Config{ ClientID: clientID, Endpoint: provider.Endpoint(), - Scopes: []string{"openid", "profile", "okta.apps.read", "okta.apps.sso"}, + Scopes: []string{"openid", "profile", "okta.apps.sso"}, } return &cfg, nil } -// OAuth2CallbackState encapsulates all of the information from an oauth2 callback. -// -// To retrieve the Code from the struct, you must use the Verify(string) function. -type OAuth2CallbackState struct { +type callbackState struct { code string state string errorMessage string errorDescription string } -// FromRequest parses the given http.Request and populates the OAuth2CallbackState with those values. -func (o *OAuth2CallbackState) FromRequest(r *http.Request) { - o.errorMessage = r.FormValue("error") - o.errorDescription = r.FormValue("error_description") - o.state = r.FormValue("state") - o.code = r.FormValue("code") -} - -// Verify safely compares the given state with the state from the OAuth2 callback. -// -// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned. -func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { - if o.errorMessage != "" { - return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription} - } - - if strings.Compare(o.state, expectedState) != 0 { - return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} - } - - return o.code, nil -} - -// OAuth2CallbackHandler returns a http.Handler, channel and function triple. -// -// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. -// -// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. -func OAuth2CallbackHandler() (http.Handler, <-chan OAuth2CallbackState, func()) { - // TODO: It is possible for the caller to close a panic() if they execute the function in the triplet while the handler has not yet received a request. - // That caller is us, so I don't care that much, but that probably indicates that this design is smelly. - // - // We should look at the Go SDK to see how they handle similar cases - channels that are not bound by a timer, or similar. - - ch := make(chan OAuth2CallbackState, 1) - var reqHandle, closeHandle sync.Once - closeFn := func() { - closeHandle.Do(func() { - close(ch) - }) - } - - fn := func(w http.ResponseWriter, r *http.Request) { - // This can sometimes be called multiple times, depending on the browser. - // We will simply ignore any other requests and only serve the first. - reqHandle.Do(func() { - var state OAuth2CallbackState - state.FromRequest(r) - ch <- state - closeFn() - }) - - // We still want to provide feedback to the end-user. - fmt.Fprintln(w, "You may close this window now.") - } - - return http.HandlerFunc(fn), ch, closeFn -} - -type OAuth2Error struct { - Reason string - Description string -} - -func (e OAuth2Error) Error() string { - return fmt.Sprintf("oauth2 error: %s (%s)", e.Description, e.Reason) +func parseOAuth2CallbackState(r *http.Request, info *callbackState) error { + err := r.ParseForm() + info.errorMessage = r.FormValue("error") + info.errorDescription = r.FormValue("error_description") + info.state = r.FormValue("state") + info.code = r.FormValue("code") + return err } -func GeneratePkceChallenge() PkceChallenge { - codeVerifierBuf := make([]byte, stateBufSize) - rand.Read(codeVerifierBuf) - codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBuf) - codeChallengeHash := sha256.Sum256([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(codeChallengeHash[:]) - return PkceChallenge{Verifier: codeVerifier, Challenge: codeChallenge} -} - -func GenerateState() string { +func generateState() string { stateBuf := make([]byte, stateBufSize) rand.Read(stateBuf) return base64.URLEncoding.EncodeToString(stateBuf) } -type PkceChallenge struct { - Challenge string - Verifier string -} +type Session struct { + url string + state string + verifier string -type RedirectionFlowHandler struct { - Config *oauth2.Config - OnDisplayURL func(url string) error + Token chan *oauth2.Token + Error chan error } -func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, listener net.Listener, challenge PkceChallenge, state string) (*oauth2.Token, error) { - if r.OnDisplayURL == nil { - panic("OnDisplayURL must be set") - } - - url := r.Config.AuthCodeURL(state, - oauth2.SetAuthURLParam("code_challenge_method", "S256"), - oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), - ) - - callbackHandler, ch, cancel := OAuth2CallbackHandler() - // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(listener, callbackHandler) - defer cancel() - - if err := r.OnDisplayURL(url); err != nil { - // This is unlikely to ever happen - return nil, fmt.Errorf("failed to display link: %w", err) - } +func (s Session) URL() string { + return s.url +} - select { - case info := <-ch: - code, err := info.Verify(state) - if err != nil { - return nil, fmt.Errorf("failed to get authorization code: %w", err) - } - return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) - case <-ctx.Done(): - return nil, ctx.Err() +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. +func NewAuthorizationCodeHandler(config *oauth2.Config) *AuthorizationCodeHandler { + return &AuthorizationCodeHandler{ + config: config, + sessions: make(map[string]Session), } } -func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) { - oauthCfg, err := DiscoverConfig(ctx, oidcDomain, clientID) - if err != nil { - return nil, "", Error{Message: "could not discover oauth2 config", InnerError: err} +// AuthorizationCodeHandler is an http.Handler that handles the OAuth2 authorization code flow. +// +// It is intended to be used by CLIs that need to authenticate with an OAuth2 provider. +// +// Sessions can be created using NewSession, and those sessions can be used to retrieve the OAuth2 token. +type AuthorizationCodeHandler struct { + config *oauth2.Config + sessions map[string]Session + mu sync.Mutex +} + +func (a *AuthorizationCodeHandler) NewSession() Session { + state := generateState() + verifier := oauth2.GenerateVerifier() + url := a.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + // A channel capacity of 1 is used to prevent requests from blocking if they are not actively being awaited on. + s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token, 1)} + a.mu.Lock() + defer a.mu.Unlock() + a.sessions[state] = s + return s +} + +func (a *AuthorizationCodeHandler) removeSessionIfExists(state string) (Session, bool) { + a.mu.Lock() + defer a.mu.Unlock() + s, ok := a.sessions[state] + if ok { + delete(a.sessions, state) } + return s, ok +} - tok, err := exchangeAccessTokenForWebSSOToken(ctx, oauthCfg, accessToken, idToken, applicationID) - if err != nil { - return nil, "", Error{Message: "error exchanging token", InnerError: err} +func (a *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var st callbackState + if err := parseOAuth2CallbackState(r, &st); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return } - assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, oidcDomain, tok) - if err != nil { - return nil, "", Error{Message: "failed to fetch SAML assertion", InnerError: err} + session, ok := a.removeSessionIfExists(st.state) + if !ok { + http.Error(w, "no session", http.StatusBadRequest) + return } - response, err := saml.ParseEncodedResponse(string(assertionBytes)) + token, err := a.config.Exchange(r.Context(), st.code, oauth2.VerifierOption(session.verifier)) if err != nil { - return nil, "", Error{Message: "failed to parse SAML response", InnerError: err} + session.Error <- err + http.Error(w, err.Error(), http.StatusInternalServerError) + return } - return response, string(assertionBytes), nil -} - -type Error struct { - InnerError error - Message string -} - -func (o Error) Unwrap() error { - return o.InnerError -} - -func (o Error) Error() string { - return o.Message + // Make sure to respond to the user right away. If we don't, + // the server may be closed before a response can be sent. + fmt.Fprintln(w, "You may close this window now.") + session.Token <- token + close(session.Token) } diff --git a/oauth2/oauth2_test.go b/oauth2/oauth2_test.go index f5d2ef2d..ea668618 100644 --- a/oauth2/oauth2_test.go +++ b/oauth2/oauth2_test.go @@ -6,9 +6,37 @@ import ( "net/url" "testing" - "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" ) +var cfg = &oauth2.Config{ + ClientID: "client-id", + Endpoint: oauth2.Endpoint{ + TokenURL: "http://localhost/oauth2/token", + }, +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +var client = http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/oauth2/token" { + return &http.Response{ + StatusCode: http.StatusOK, + Body: nil, + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: nil, + }, nil + }), +} + func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { uri := url.URL{ Scheme: "http", @@ -22,59 +50,9 @@ func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { handler.ServeHTTP(w, req) } -func Test_OAuth2CallbackHandler_YieldsCorrectlyFormattedState(t *testing.T) { - handler, ch, cancel := OAuth2CallbackHandler() - t.Cleanup(func() { - cancel() - }) - - expectedState := "state goes here" - expectedCode := "code goes here" - - go sendOAuth2CallbackRequest(handler, url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, - }) - - callbackState := <-ch - code, err := callbackState.Verify(expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) -} - -func Test_OAuth2CallbackState_VerifyWorksCorrectly(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - expectedState := "state goes here" - expectedCode := "code goes here" - callbackState := OAuth2CallbackState{ - code: expectedCode, - state: expectedState, - } - code, err := callbackState.Verify(expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) - }) - - t.Run("unhappy path", func(t *testing.T) { - expectedState := "state goes here" - expectedCode := "code goes here" - callbackState := OAuth2CallbackState{ - code: expectedCode, - state: expectedState, - } - _, err := callbackState.Verify("mismatching state") - var oauthErr OAuth2Error - assert.ErrorAs(t, err, &oauthErr) - assert.Equal(t, "invalid_state", oauthErr.Reason) - }) -} - -// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel -func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { - handler, ch, cancel := OAuth2CallbackHandler() - t.Cleanup(func() { - cancel() - }) +// Test_AuthorizationCodeHandler_IsReentrant prevents an issue where AuthorizationCodeHandler would send a request to a closed channel +func Test_AuthorizationCodeHandler_IsReentrant(t *testing.T) { + handler := NewAuthorizationCodeHandler(cfg) go sendOAuth2CallbackRequest(handler, url.Values{ // We send empty values because we don't care about processing in this test @@ -82,13 +60,10 @@ func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { "state": []string{""}, }) - // We drain the channel of the first request so the handler completes. - // Without this step, we would get 'stuck' in the sync.Once(). - <-ch - // We send this request synchronously to ensure that any panics are caught during the test. sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{"not the expected code and should be discarded"}, "state": []string{"not the expected state and should be discarded"}, }) + // If we reach here with no panics, it should pass } diff --git a/oauth2/html.go b/okta/html.go similarity index 99% rename from oauth2/html.go rename to okta/html.go index d01681e6..7d94df1b 100644 --- a/oauth2/html.go +++ b/okta/html.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "errors" diff --git a/oauth2/html_test.go b/okta/html_test.go similarity index 99% rename from oauth2/html_test.go rename to okta/html_test.go index 391fd65c..961a3b57 100644 --- a/oauth2/html_test.go +++ b/okta/html_test.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "strings" diff --git a/oauth2/websso.go b/okta/websso.go similarity index 68% rename from oauth2/websso.go rename to okta/websso.go index 379a0c10..2dfd3970 100644 --- a/oauth2/websso.go +++ b/okta/websso.go @@ -1,4 +1,4 @@ -package oauth2 +package okta import ( "context" @@ -7,10 +7,13 @@ import ( "net/http" "net/url" + "github.com/RobotsAndPencils/go-saml" "golang.org/x/net/html" "golang.org/x/oauth2" ) +var ErrNoSAMLAssertion = errors.New("no saml assertion") + // exchangeAccessTokenForWebSSOToken exchanges an OAuth2 token for an Okta Web SSO token. // // An Okta Web SSO token is a non-standard authorization token for Okta's Web SSO endpoint. @@ -30,17 +33,12 @@ func exchangeAccessTokenForWebSSOToken(ctx context.Context, oauthCfg *oauth2.Con // exchangeWebSSOTokenForSAMLAssertion is an Okta-specific API which exchanges an Okta Web SSO token, which is obtained by exchanging an OAuth2 token using the RFC8693 Token Exchange Flow, for a SAML assertion. // // It is not standards compliant, but is used by Okta in their own okta-aws-cli. -func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, token *oauth2.Token) ([]byte, error) { +func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, client *http.Client, issuer string, token *oauth2.Token) ([]byte, error) { data := url.Values{"token": {token.AccessToken}} uri := fmt.Sprintf("%s/login/token/sso?%s", issuer, data.Encode()) req, _ := http.NewRequestWithContext(ctx, "GET", uri, nil) req.Header.Add("Accept", "text/html") - client := http.DefaultClient - if val, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { - client = val - } - resp, err := client.Do(req) if err != nil { return nil, err @@ -63,3 +61,22 @@ func exchangeWebSSOTokenForSAMLAssertion(ctx context.Context, issuer string, tok return []byte(saml), nil } + +func ExchangeTokenForAssertion(ctx context.Context, cfg *oauth2.Config, accessToken, idToken, oidcDomain, applicationID string) (*saml.Response, string, error) { + tok, err := exchangeAccessTokenForWebSSOToken(ctx, cfg, accessToken, idToken, applicationID) + if err != nil { + return nil, "", fmt.Errorf("error exchanging token: %w", err) + } + + assertionBytes, err := exchangeWebSSOTokenForSAMLAssertion(ctx, cfg.Client(ctx, nil), oidcDomain, tok) + if err != nil { + return nil, "", fmt.Errorf("failed to fetch SAML assertion: %w", err) + } + + response, err := saml.ParseEncodedResponse(string(assertionBytes)) + if err != nil { + return nil, "", fmt.Errorf("failed to parse SAML response: %w", err) + } + + return response, string(assertionBytes), nil +}