From 80673b4a4bfc6c2c58a0b44cf9106913fe293994 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 14 Jan 2019 22:52:49 +0000 Subject: [PATCH] oauth2: auto-detect auth style by default, add Endpoint.AuthStyle Instead of maintaining a global map of which OAuth2 servers do which auth style and/or requiring the user to tell us, just try both ways and remember which way worked. But if users want to tell us in the Endpoint, this CL also add Endpoint.AuthStyle. Fixes golang/oauth2#111 Fixes golang/oauth2#365 Fixes golang/oauth2#362 Fixes golang/oauth2#357 Fixes golang/oauth2#353 Fixes golang/oauth2#345 Fixes golang/oauth2#326 Fixes golang/oauth2#352 Fixes golang/oauth2#268 Fixes https://go-review.googlesource.com/c/oauth2/+/58510 (... and surely many more ...) Change-Id: I7b4d98ba1900ee2d3e11e629316b0bf867f7d237 Reviewed-on: https://go-review.googlesource.com/c/157820 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Ross Light --- clientcredentials/clientcredentials.go | 8 +- clientcredentials/clientcredentials_test.go | 17 +- google/google.go | 5 +- internal/token.go | 210 +++++++++++--------- internal/token_test.go | 59 +----- linkedin/linkedin.go | 5 +- oauth2.go | 45 +++-- oauth2_test.go | 32 +-- token.go | 2 +- 9 files changed, 200 insertions(+), 183 deletions(-) diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 081296492..7a0b9ed10 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -42,6 +42,11 @@ type Config struct { // EndpointParams specifies additional parameters for requests to the token endpoint. EndpointParams url.Values + + // AuthStyle optionally specifies how the endpoint wants the + // client ID & client secret sent. The zero value means to + // auto-detect. + AuthStyle oauth2.AuthStyle } // Token uses client credentials to retrieve a token. @@ -97,7 +102,8 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { } v[k] = p } - tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v) + + tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*oauth2.RetrieveError)(rErr) diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 35cbd5499..02a1c89a8 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -6,11 +6,14 @@ package clientcredentials import ( "context" + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" + + "golang.org/x/oauth2/internal" ) func newConf(serverURL string) *Config { @@ -111,21 +114,25 @@ func TestTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { + internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return } if r.URL.String() != "/token" { - t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + t.Errorf("Unexpected token refresh request URL: %q", r.URL) } headerContentType := r.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type = %q; want %q", got, want) } body, _ := ioutil.ReadAll(r.Body) - if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" { - t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" + if string(body) != want { + t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want) } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`) })) defer ts.Close() conf := newConf(ts.URL) diff --git a/google/google.go b/google/google.go index ca7d208d7..4b0b54720 100644 --- a/google/google.go +++ b/google/google.go @@ -19,8 +19,9 @@ import ( // Endpoint is Google's OAuth 2.0 endpoint. var Endpoint = oauth2.Endpoint{ - AuthURL: "https://accounts.google.com/o/oauth2/auth", - TokenURL: "https://accounts.google.com/o/oauth2/token", + AuthURL: "https://accounts.google.com/o/oauth2/auth", + TokenURL: "https://accounts.google.com/o/oauth2/token", + AuthStyle: oauth2.AuthStyleInParams, } // JWTTokenURL is Google's OAuth 2.0 token URL to use with the JWT flow. diff --git a/internal/token.go b/internal/token.go index a831b7746..0f75a182f 100644 --- a/internal/token.go +++ b/internal/token.go @@ -16,6 +16,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "golang.org/x/net/context/ctxhttp" @@ -90,102 +91,71 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -var brokenAuthHeaderProviders = []string{ - "https://accounts.google.com/", - "https://api.codeswholesale.com/oauth/token", - "https://api.dropbox.com/", - "https://api.dropboxapi.com/", - "https://api.instagram.com/", - "https://api.netatmo.net/", - "https://api.odnoklassniki.ru/", - "https://api.pushbullet.com/", - "https://api.soundcloud.com/", - "https://api.twitch.tv/", - "https://id.twitch.tv/", - "https://app.box.com/", - "https://api.box.com/", - "https://connect.stripe.com/", - "https://login.mailchimp.com/", - "https://login.microsoftonline.com/", - "https://login.salesforce.com/", - "https://login.windows.net", - "https://login.live.com/", - "https://login.live-int.com/", - "https://oauth.sandbox.trainingpeaks.com/", - "https://oauth.trainingpeaks.com/", - "https://oauth.vk.com/", - "https://openapi.baidu.com/", - "https://slack.com/", - "https://test-sandbox.auth.corp.google.com", - "https://test.salesforce.com/", - "https://user.gini.net/", - "https://www.douban.com/", - "https://www.googleapis.com/", - "https://www.linkedin.com/", - "https://www.strava.com/oauth/", - "https://www.wunderlist.com/oauth/", - "https://api.patreon.com/", - "https://sandbox.codeswholesale.com/oauth/token", - "https://api.sipgate.com/v1/authorization/oauth", - "https://api.medium.com/v1/tokens", - "https://log.finalsurge.com/oauth/token", - "https://multisport.todaysplan.com.au/rest/oauth/access_token", - "https://whats.todaysplan.com.au/rest/oauth/access_token", - "https://stackoverflow.com/oauth/access_token", - "https://account.health.nokia.com", - "https://accounts.zoho.com", - "https://gitter.im/login/oauth/token", - "https://openid-connect.onelogin.com/oidc", - "https://api.dailymotion.com/oauth/token", +// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. +// +// Deprecated: this function no longer does anything. Caller code that +// wants to avoid potential extra HTTP requests made during +// auto-probing of the provider's auth style should set +// Endpoint.AuthStyle. +func RegisterBrokenAuthHeaderProvider(tokenURL string) {} + +// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. +type AuthStyle int + +const ( + AuthStyleUnknown AuthStyle = 0 + AuthStyleInParams AuthStyle = 1 + AuthStyleInHeader AuthStyle = 2 +) + +// authStyleCache is the set of tokenURLs we've successfully used via +// RetrieveToken and which style auth we ended up using. +// It's called a cache, but it doesn't (yet?) shrink. It's expected that +// the set of OAuth2 servers a program contacts over time is fixed and +// small. +var authStyleCache struct { + sync.Mutex + m map[string]AuthStyle // keyed by tokenURL } -// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints. -var brokenAuthHeaderDomains = []string{ - ".auth0.com", - ".force.com", - ".myshopify.com", - ".okta.com", - ".oktapreview.com", +// ResetAuthCache resets the global authentication style cache used +// for AuthStyleUnknown token requests. +func ResetAuthCache() { + authStyleCache.Lock() + defer authStyleCache.Unlock() + authStyleCache.m = nil } -func RegisterBrokenAuthHeaderProvider(tokenURL string) { - brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL) +// lookupAuthStyle reports which auth style we last used with tokenURL +// when calling RetrieveToken and whether we have ever done so. +func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { + authStyleCache.Lock() + defer authStyleCache.Unlock() + style, ok = authStyleCache.m[tokenURL] + return } -// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL -// implements the OAuth2 spec correctly -// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. -// In summary: -// - Reddit only accepts client secret in the Authorization header -// - Dropbox accepts either it in URL param or Auth header, but not both. -// - Google only accepts URL param (not spec compliant?), not Auth header -// - Stripe only accepts client secret in Auth header with Bearer method, not Basic -func providerAuthHeaderWorks(tokenURL string) bool { - for _, s := range brokenAuthHeaderProviders { - if strings.HasPrefix(tokenURL, s) { - // Some sites fail to implement the OAuth2 spec fully. - return false - } +// setAuthStyle adds an entry to authStyleCache, documented above. +func setAuthStyle(tokenURL string, v AuthStyle) { + authStyleCache.Lock() + defer authStyleCache.Unlock() + if authStyleCache.m == nil { + authStyleCache.m = make(map[string]AuthStyle) } - - if u, err := url.Parse(tokenURL); err == nil { - for _, s := range brokenAuthHeaderDomains { - if strings.HasSuffix(u.Host, s) { - return false - } - } - } - - // Assume the provider implements the spec properly - // otherwise. We can add more exceptions as they're - // discovered. We will _not_ be adding configurable hooks - // to this package to let users select server bugs. - return true + authStyleCache.m[tokenURL] = v } -func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { - bustedAuth := !providerAuthHeaderWorks(tokenURL) - if bustedAuth { +// newTokenRequest returns a new *http.Request to retrieve a new token +// from tokenURL using the provided clientID, clientSecret, and POST +// body parameters. +// +// inParams is whether the clientID & clientSecret should be encoded +// as the POST body. An 'inParams' value of true means to send it in +// the POST body (along with any values in v); false means to send it +// in the Authorization header. +func newTokenRequest(tokenURL, clientID, clientSecret string, v url.Values, authStyle AuthStyle) (*http.Request, error) { + if authStyle == AuthStyleInParams { + v = cloneURLValues(v) if clientID != "" { v.Set("client_id", clientID) } @@ -198,15 +168,70 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if !bustedAuth { + if authStyle == AuthStyleInHeader { req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret)) } + return req, nil +} + +func cloneURLValues(v url.Values) url.Values { + v2 := make(url.Values, len(v)) + for k, vv := range v { + v2[k] = append([]string(nil), vv...) + } + return v2 +} + +func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { + needsAuthStyleProbe := authStyle == 0 + if needsAuthStyleProbe { + if style, ok := lookupAuthStyle(tokenURL); ok { + authStyle = style + needsAuthStyleProbe = false + } else { + authStyle = AuthStyleInHeader // the first way we'll try + } + } + req, err := newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + if err != nil { + return nil, err + } + token, err := doTokenRoundTrip(ctx, req) + if err != nil && needsAuthStyleProbe { + // If we get an error, assume the server wants the + // clientID & clientSecret in a different form. + // See https://code.google.com/p/goauth2/issues/detail?id=31 for background. + // In summary: + // - Reddit only accepts client secret in the Authorization header + // - Dropbox accepts either it in URL param or Auth header, but not both. + // - Google only accepts URL param (not spec compliant?), not Auth header + // - Stripe only accepts client secret in Auth header with Bearer method, not Basic + // + // We used to maintain a big table in this code of all the sites and which way + // they went, but maintaining it didn't scale & got annoying. + // So just try both ways. + authStyle = AuthStyleInParams // the second way we'll try + req, _ = newTokenRequest(tokenURL, clientID, clientSecret, v, authStyle) + token, err = doTokenRoundTrip(ctx, req) + } + if needsAuthStyleProbe && err == nil { + setAuthStyle(tokenURL, authStyle) + } + // Don't overwrite `RefreshToken` with an empty value + // if this was a token refreshing request. + if token != nil && token.RefreshToken == "" { + token.RefreshToken = v.Get("refresh_token") + } + return token, err +} + +func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { r, err := ctxhttp.Do(ctx, ContextClient(ctx), req) if err != nil { return nil, err } - defer r.Body.Close() body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + r.Body.Close() if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -256,13 +281,8 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } - // Don't overwrite `RefreshToken` with an empty value - // if this was a token refreshing request. - if token.RefreshToken == "" { - token.RefreshToken = v.Get("refresh_token") - } if token.AccessToken == "" { - return token, errors.New("oauth2: server response missing access_token") + return nil, errors.New("oauth2: server response missing access_token") } return token, nil } diff --git a/internal/token_test.go b/internal/token_test.go index d1da8bb04..d8373c25a 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -6,7 +6,6 @@ package internal import ( "context" - "fmt" "io" "net/http" "net/http/httptest" @@ -14,17 +13,9 @@ import ( "testing" ) -func TestRegisterBrokenAuthHeaderProvider(t *testing.T) { - RegisterBrokenAuthHeaderProvider("https://aaa.com/") - tokenURL := "https://aaa.com/token" - if providerAuthHeaderWorks(tokenURL) { - t.Errorf("got %q as unbroken; want broken", tokenURL) - } -} - -func TestRetrieveTokenBustedNoSecret(t *testing.T) { +func TestRetrieveToken_InParams(t *testing.T) { + ResetAuthCache() const clientID = "client-id" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got, want := r.FormValue("client_id"), clientID; got != want { t.Errorf("client_id = %q; want %q", got, want) @@ -36,52 +27,14 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) { io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) })) defer ts.Close() - - RegisterBrokenAuthHeaderProvider(ts.URL) - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams) if err != nil { t.Errorf("RetrieveToken = %v; want no error", err) } } -func Test_providerAuthHeaderWorks(t *testing.T) { - for _, p := range brokenAuthHeaderProviders { - if providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } - p := fmt.Sprintf("%ssomesuffix", p) - if providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } - } - p := "https://api.not-in-the-list-example.com/" - if !providerAuthHeaderWorks(p) { - t.Errorf("got %q as unbroken; want broken", p) - } -} - -func TestProviderAuthHeaderWorksDomain(t *testing.T) { - tests := []struct { - tokenURL string - wantWorks bool - }{ - {"https://dev-12345.okta.com/token-url", false}, - {"https://dev-12345.oktapreview.com/token-url", false}, - {"https://dev-12345.okta.org/token-url", true}, - {"https://foo.bar.force.com/token-url", false}, - {"https://foo.force.com/token-url", false}, - {"https://force.com/token-url", true}, - } - - for _, test := range tests { - got := providerAuthHeaderWorks(test.tokenURL) - if got != test.wantWorks { - t.Errorf("providerAuthHeaderWorks(%q) = %v; want %v", test.tokenURL, got, test.wantWorks) - } - } -} - func TestRetrieveTokenWithContexts(t *testing.T) { + ResetAuthCache() const clientID = "client-id" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -90,7 +43,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { })) defer ts.Close() - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown) if err != nil { t.Errorf("RetrieveToken (with background context) = %v; want no error", err) } @@ -103,7 +56,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}) + _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown) close(retrieved) if err == nil { t.Errorf("RetrieveToken (with cancelled context) = nil; want error") diff --git a/linkedin/linkedin.go b/linkedin/linkedin.go index 62d1de817..d3972771c 100644 --- a/linkedin/linkedin.go +++ b/linkedin/linkedin.go @@ -11,6 +11,7 @@ import ( // Endpoint is LinkedIn's OAuth 2.0 endpoint. var Endpoint = oauth2.Endpoint{ - AuthURL: "https://www.linkedin.com/oauth/v2/authorization", - TokenURL: "https://www.linkedin.com/oauth/v2/accessToken", + AuthURL: "https://www.linkedin.com/oauth/v2/authorization", + TokenURL: "https://www.linkedin.com/oauth/v2/accessToken", + AuthStyle: oauth2.AuthStyleInParams, } diff --git a/oauth2.go b/oauth2.go index 3de63315b..ec6ee004c 100644 --- a/oauth2.go +++ b/oauth2.go @@ -26,17 +26,13 @@ import ( // Deprecated: Use context.Background() or context.TODO() instead. var NoContext = context.TODO() -// RegisterBrokenAuthHeaderProvider registers an OAuth2 server -// identified by the tokenURL prefix as an OAuth2 implementation -// which doesn't support the HTTP Basic authentication -// scheme to authenticate with the authorization server. -// Once a server is registered, credentials (client_id and client_secret) -// will be passed as parameters in the request body rather than being present -// in the Authorization header. -// See https://code.google.com/p/goauth2/issues/detail?id=31 for background. -func RegisterBrokenAuthHeaderProvider(tokenURL string) { - internal.RegisterBrokenAuthHeaderProvider(tokenURL) -} +// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. +// +// Deprecated: this function no longer does anything. Caller code that +// wants to avoid potential extra HTTP requests made during +// auto-probing of the provider's auth style should set +// Endpoint.AuthStyle. +func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. @@ -71,13 +67,38 @@ type TokenSource interface { Token() (*Token, error) } -// Endpoint contains the OAuth 2.0 provider's authorization and token +// Endpoint represents an OAuth 2.0 provider's authorization and token // endpoint URLs. type Endpoint struct { AuthURL string TokenURL string + + // AuthStyle optionally specifies how the endpoint wants the + // client ID & client secret sent. The zero value means to + // auto-detect. + AuthStyle AuthStyle } +// AuthStyle represents how requests for tokens are authenticated +// to the server. +type AuthStyle int + +const ( + // AuthStyleAutoDetect means to auto-detect which authentication + // style the provider wants by trying both ways and caching + // the successful way for the future. + AuthStyleAutoDetect AuthStyle = 0 + + // AuthStyleInParams sends the "client_id" and "client_secret" + // in the POST body as application/x-www-form-urlencoded parameters. + AuthStyleInParams AuthStyle = 1 + + // AuthStyleInHeader sends the client_id and client_password + // using HTTP Basic Authorization. This is an optional style + // described in the OAuth2 RFC 6749 section 2.3.1. + AuthStyleInHeader AuthStyle = 2 +) + var ( // AccessTypeOnline and AccessTypeOffline are options passed // to the Options.AuthCodeURL method. They modify the diff --git a/oauth2_test.go b/oauth2_test.go index 19aaf6b2b..a059b8b37 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -8,12 +8,15 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" "net/url" "testing" "time" + + "golang.org/x/oauth2/internal" ) type mockTransport struct { @@ -93,22 +96,22 @@ func TestURLUnsafeClientConfig(t *testing.T) { func TestExchangeRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() != "/token" { - t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) + t.Errorf("Unexpected exchange request URL %q", r.URL) } headerAuth := r.Header.Get("Authorization") - if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want { + t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want) } headerContentType := r.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + t.Errorf("Unexpected Content-Type header %q", headerContentType) } body, err := ioutil.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { - t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + t.Errorf("Unexpected exchange payload; got %q", body) } w.Header().Set("Content-Type", "application/x-www-form-urlencoded") w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) @@ -343,11 +346,12 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { } func TestExchangeRequest_NonBasicAuth(t *testing.T) { + internal.ResetAuthCache() tr := &mockTransport{ rt: func(r *http.Request) (w *http.Response, err error) { headerAuth := r.Header.Get("Authorization") if headerAuth != "" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + t.Errorf("Unexpected authorization header %q", headerAuth) } return nil, errors.New("no response") }, @@ -356,8 +360,9 @@ func TestExchangeRequest_NonBasicAuth(t *testing.T) { conf := &Config{ ClientID: "CLIENT_ID", Endpoint: Endpoint{ - AuthURL: "https://accounts.google.com/auth", - TokenURL: "https://accounts.google.com/token", + AuthURL: "https://accounts.google.com/auth", + TokenURL: "https://accounts.google.com/token", + AuthStyle: AuthStyleInParams, }, } @@ -413,21 +418,24 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { + internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return } if r.URL.String() != "/token" { - t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) + t.Errorf("Unexpected token refresh request URL %q", r.URL) } headerContentType := r.Header.Get("Content-Type") if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + t.Errorf("Unexpected Content-Type header %q", headerContentType) } body, _ := ioutil.ReadAll(r.Body) if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { - t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) + t.Errorf("Unexpected refresh token payload %q", body) } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`) })) defer ts.Close() conf := newConf(ts.URL) @@ -478,7 +486,7 @@ func TestTokenRetrieveError(t *testing.T) { } _, ok := err.(*RetrieveError) if !ok { - t.Fatalf("got %T error, expected *RetrieveError", err) + t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) } // Test error string for backwards compatibility expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) diff --git a/token.go b/token.go index ee4be545f..822720341 100644 --- a/token.go +++ b/token.go @@ -154,7 +154,7 @@ func tokenFromInternal(t *internal.Token) *Token { // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // with an error.. func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { - tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle)) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*RetrieveError)(rErr)