Skip to content

Commit

Permalink
Ensure the server is closed after receiving a request
Browse files Browse the repository at this point in the history
  • Loading branch information
punmechanic committed Dec 4, 2024
1 parent 3be4301 commit 384ffc9
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions oauth2/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type CodeExchanger interface {
// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block.
func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan<- Callback) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {

// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
var info OAuth2CallbackState
Expand All @@ -105,14 +106,16 @@ func OAuth2CallbackHandler(codeEx CodeExchanger, state, verifier string, ch chan
return
}

// Make sure to respond to the user right away. If we don't,
// the server may be closed before a response can be sent.
fmt.Fprintln(w, "You may close this window now.")

// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
if idToken, ok := token.Extra("id_token").(string); ok {
ch <- Callback{Token: token, IDToken: &idToken}
} else {
ch <- Callback{Token: token}
}

fmt.Fprintln(w, "You may close this window now.")
}

return http.HandlerFunc(fn)
Expand Down Expand Up @@ -159,18 +162,22 @@ func (r *AuthorizationCodeHandler) NewSession() Session {

func (r AuthorizationCodeHandler) WaitForToken(ctx context.Context, listener net.Listener, session Session) (*oauth2.Token, string, error) {
ch := make(chan Callback, 1)
// TODO: This error probably should not be ignored if it is not http.ErrServerClosed
go http.Serve(listener, OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch))
server := http.Server{
Handler: OAuth2CallbackHandler(r.Config, session.state, session.verifier, ch),
}

go server.Serve(listener)

select {
case info := <-ch:
// TODO: Close the server immediately to prevent any more requests being received.
server.Close()
if info.Error != nil {
return nil, "", info.Error
}

return info.Token, "", nil
case <-ctx.Done():
server.Close()
return nil, "", ctx.Err()
}
}
Expand Down

0 comments on commit 384ffc9

Please sign in to comment.