diff --git a/cmd/example-app/main.go b/cmd/example-app/main.go index 4da34b9a8a..d527bbc8e8 100644 --- a/cmd/example-app/main.go +++ b/cmd/example-app/main.go @@ -37,8 +37,7 @@ type app struct { // or does it use "access_type=offline" (e.g. Google)? offlineAsScope bool - ctx context.Context - cancel context.CancelFunc + client *http.Client } // return an HTTP client which trusts the provided root CAs. @@ -118,31 +117,31 @@ func cmd() *cobra.Command { return fmt.Errorf("parse listen address: %v", err) } - a.ctx, a.cancel = context.WithCancel(context.Background()) - if rootCAs != "" { client, err := httpClientForRootCAs(rootCAs) if err != nil { return err } - - // This sets the OAuth2 client and oidc client. - a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client) + a.client = client } if debug { - client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client) - if ok { - client.Transport = debugTransport{client.Transport} - } else { - a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{ + if a.client == nil { + a.client = &http.Client{ Transport: debugTransport{http.DefaultTransport}, - }) + } + } else { + a.client.Transport = debugTransport{a.client.Transport} } } + if a.client == nil { + a.client = http.DefaultClient + } + // TODO(ericchiang): Retry with backoff - provider, err := oidc.NewProvider(a.ctx, issuerURL) + ctx := oidc.ClientContext(context.Background(), a.client) + provider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) } @@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { err error token *oauth2.Token ) + + ctx := oidc.ClientContext(r.Context(), a.client) oauth2Config := a.oauth2Config(nil) switch r.Method { case "GET": @@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) return } - token, err = oauth2Config.Exchange(a.ctx, code) + token, err = oauth2Config.Exchange(ctx, code) case "POST": // Form request from frontend to refresh a token. refresh := r.FormValue("refresh_token") @@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { RefreshToken: refresh, Expiry: time.Now().Add(-time.Hour), } - token, err = oauth2Config.TokenSource(r.Context(), t).Token() + token, err = oauth2Config.TokenSource(ctx, t).Token() default: http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) return