Skip to content

Commit

Permalink
Use channels to clean up AuthorizationCodeHandler
Browse files Browse the repository at this point in the history
This change makes AuthorizationCodeHandler reentrant and prevents
panics that could occur if two requests occurred before the server was
closed.

It also makes the API nicer.
  • Loading branch information
punmechanic committed Dec 6, 2024
1 parent 384ffc9 commit ed49529
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 96 deletions.
25 changes: 20 additions & 5 deletions command/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config *
}
oauthCfg.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port))

handler := oauth2.AuthorizationCodeHandler{Config: oauthCfg}
handler := &oauth2.AuthorizationCodeHandler{Config: oauthCfg}
session := handler.NewSession()
if !c.Browser {
if isPiped() || globals.Quiet {
Expand All @@ -59,12 +59,27 @@ func (c LoginCommand) RunContext(ctx context.Context, globals *Globals, config *
browser.OpenURL(session.URL())
}

accessToken, idToken, err := handler.WaitForToken(ctx, sock, session)
if err != nil {
errCh := make(chan error, 1)
go func() {
err := http.Serve(sock, handler)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
}()

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
return err
case err := <-session.Error:
return err
case token := <-session.Token:
// TODO Will panic if id_token not present
// TODO Verify token with OIDC provider
idToken := token.Extra("id_token").(string)
return config.SaveOAuthToken(token, idToken)
}

return config.SaveOAuthToken(accessToken, idToken)
}

func (c LoginCommand) Run(globals *Globals, config *Config) error {
Expand Down
124 changes: 33 additions & 91 deletions oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ import (
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"

"github.com/RobotsAndPencils/go-saml"
"github.com/coreos/go-oidc"
Expand Down Expand Up @@ -54,73 +53,6 @@ func (o *OAuth2CallbackState) FromRequest(r *http.Request) {
o.code = r.FormValue("code")
}

// Verify safely compares the given state with the state from the OAuth2 callback.
//
// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned.
func (o OAuth2CallbackState) Verify(expectedState string) (string, error) {
if o.errorMessage != "" {
return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription}
}

if strings.Compare(o.state, expectedState) != 0 {
return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"}
}

return o.code, nil
}

type Callback struct {
Token *oauth2.Token
IDToken *string
Error error
}

type CodeExchanger interface {
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
}

// OAuth2CallbackHandler returns a http.Handler, channel and function triple.
//
// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored.
//
// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block.
func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {

// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
var info OAuth2CallbackState
info.FromRequest(r)

code, err := info.Verify(state)
if err != nil {
ch <- Callback{Error: err}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

token, err := codeEx.Exchange(r.Context(), code, oauth2.VerifierOption(verifier))
if err != nil {
ch <- Callback{Error: err}
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// Make sure to respond to the user right away. If we don't,
// the server may be closed before a response can be sent.
fmt.Fprintln(w, "You may close this window now.")

// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
if idToken, ok := token.Extra("id_token").(string); ok {
ch <- Callback{Token: token, IDToken: &idToken}
} else {
ch <- Callback{Token: token}
}
}

return http.HandlerFunc(fn)
}

type OAuth2Error struct {
Reason string
Description string
Expand All @@ -140,46 +72,56 @@ type Session struct {
url string
state string
verifier string

Token chan *oauth2.Token
Error chan error
}

func (s Session) URL() string {
return s.url
}

type AuthorizationCodeHandler struct {
Config *oauth2.Config
Config *oauth2.Config

sessions map[string]Session
mu sync.Mutex
}

func (r *AuthorizationCodeHandler) NewSession() Session {
func (h *AuthorizationCodeHandler) NewSession() Session {
state := generateState()
verifier := oauth2.GenerateVerifier()
url := r.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))
s := Session{verifier: verifier, state: state, url: url}
r.sessions[state] = s
url := h.Config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))
s := Session{verifier: verifier, state: state, url: url, Token: make(chan *oauth2.Token)}
h.mu.Lock()
defer h.mu.Unlock()
h.sessions[state] = s
return s
}

func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) {
ch := make(chan Callback, 1)
server := http.Server{
Handler: OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch),
}

go server.Serve(listener)
func (h *AuthorizationCodeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var info OAuth2CallbackState
info.FromRequest(r)

select {
case info := <-ch:
server.Close()
if info.Error != nil {
return nil, "", info.Error
}
h.mu.Lock()
defer h.mu.Unlock()
session, ok := h.sessions[info.state]
if !ok {
http.Error(w, "no session", http.StatusBadRequest)
return
}

return info.Token, "", nil
case <-ctx.Done():
server.Close()
return nil, "", ctx.Err()
token, err := h.Config.Exchange(r.Context(), info.code, oauth2.VerifierOption(session.verifier))
if err != nil {
session.Error <- err
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

// Make sure to respond to the user right away. If we don't,
// the server may be closed before a response can be sent.
fmt.Fprintln(w, "You may close this window now.")
session.Token <- token
}

func DiscoverConfigAndExchangeTokenForAssertion(ctx context.Context, accessToken, idToken, oidcDomain, clientID, applicationID string) (*saml.Response, string, error) {
Expand Down

0 comments on commit ed49529

Please sign in to comment.