diff --git a/command/login.go b/command/login.go index 8febe50f..12d7d02f 100644 --- a/command/login.go +++ b/command/login.go @@ -47,7 +47,7 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * } oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) - handler := oauth2.AuthorizationCodeHandler{Config: oauthCfg} + handler := &oauth2.AuthorizationCodeHandler{Config: oauthCfg} session := handler.NewSession() if !c.Browser { if isPiped() || globals.Quiet { @@ -59,12 +59,27 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config * browser.OpenURL(session.URL()) } - accessToken, idToken, err := handler.WaitForToken(ctx, sock, session) - if err != nil { + errCh := make(chan error, 1) + go func() { + err := http.Serve(sock, handler) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: return err + case err := <-session.Error: + return err + case token := <-session.Token: + // TODO Will panic if id_token not present + // TODO Verify token with OIDC provider + idToken := token.Extra("id_token").(string) + return config.SaveOAuthToken(token, idToken) } - - return config.SaveOAuthToken(accessToken, idToken) } func (c LoginCommand) Run(globals *Globals, config *Config) error { diff --git a/oauth2/oauth2.go b/oauth2/oauth2.go index 122d69cd..4ada7dd0 100644 --- a/oauth2/oauth2.go +++ b/oauth2/oauth2.go @@ -6,9 +6,8 @@ import ( "encoding/base64" "errors" "fmt" - "net" "net/http" - "strings" + "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -54,73 +53,6 @@ func (o *OAuth2CallbackState) FromRequest(r *http.Request) { o.code = r.FormValue("code") } -// Verify safely compares the given state with the state from the OAuth2 callback. -// -// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned. -func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { - if o.errorMessage != "" { - return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription} - } - - if strings.Compare(o.state, expectedState) != 0 { - return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} - } - - return o.code, nil -} - -type Callback struct { - Token *oauth2.Token - IDToken *string - Error error -} - -type CodeExchanger interface { - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) -} - -// OAuth2CallbackHandler returns a http.Handler, channel and function triple. -// -// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. -// -// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. -func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - - // This can sometimes be called multiple times, depending on the browser. - // We will simply ignore any other requests and only serve the first. - var info OAuth2CallbackState - info.FromRequest(r) - - code, err := info.Verify(state) - if err != nil { - ch <- Callback{Error: err} - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - token, err := codeEx.Exchange(r.Context(), code, oauth2.VerifierOption(verifier)) - if err != nil { - ch <- Callback{Error: err} - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Make sure to respond to the user right away. If we don't, - // the server may be closed before a response can be sent. - fmt.Fprintln(w, "You may close this window now.") - - // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - if idToken, ok := token.Extra("id_token").(string); ok { - ch <- Callback{Token: token, IDToken: &idToken} - } else { - ch <- Callback{Token: token} - } - } - - return http.HandlerFunc(fn) -} - type OAuth2Error struct { Reason string Description string @@ -140,6 +72,9 @@ type Session struct { url string state string verifier string + + Token chan *oauth2.Token + Error chan error } func (s Session) URL() string { @@ -147,39 +82,46 @@ func (s Session) URL() string { } type AuthorizationCodeHandler struct { - Config *oauth2.Config + Config *oauth2.Config + sessions map[string]Session + mu sync.Mutex } -func (r *AuthorizationCodeHandler) NewSession() Session { +func (h *AuthorizationCodeHandler) NewSession() Session { state := generateState() verifier := oauth2.GenerateVerifier() - url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) - s := Session{verifier: verifier, state: state, url: url} - r.sessions[state] = s + url := h.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token)} + h.mu.Lock() + defer h.mu.Unlock() + h.sessions[state] = s return s } -func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) { - ch := make(chan Callback, 1) - server := http.Server{ - Handler: OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch), - } - - go server.Serve(listener) +func (h *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var info OAuth2CallbackState + info.FromRequest(r) - select { - case info := <-ch: - server.Close() - if info.Error != nil { - return nil, "", info.Error - } + h.mu.Lock() + defer h.mu.Unlock() + session, ok := h.sessions[info.state] + if !ok { + http.Error(w, "no session", http.StatusBadRequest) + return + } - return info.Token, "", nil - case <-ctx.Done(): - server.Close() - return nil, "", ctx.Err() + token, err := h.Config.Exchange(r.Context(), info.code, oauth2.VerifierOption(session.verifier)) + if err != nil { + session.Error <- err + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + + // Make sure to respond to the user right away. If we don't, + // the server may be closed before a response can be sent. + fmt.Fprintln(w, "You may close this window now.") + session.Token <- token } func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) {