From c14691a3c13634f783814df30d7e1ab647b6adc6 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Tue, 22 Aug 2023 09:47:31 -0500 Subject: [PATCH 1/5] feat(auth): add base auth package This package provides: - A TokenProvider interface - A Token type - A standard auth Error type - Configuration for 2L0 oauth2 flows - Configuration for 3L0 oauth2 flows This code has been adapted from the golang oauth2 repo. It should feel familiar for anyone who as worked with that library before, but it only provdies a sub-set of the features that we require for our client libraries. --- .release-please-manifest-individual.json | 3 +- auth/README.md | 4 + auth/auth.go | 341 ++++++++++++++++ auth/auth_test.go | 445 +++++++++++++++++++++ auth/example_test.go | 58 +++ auth/go.mod | 3 + auth/internal/internal.go | 112 ++++++ auth/internal/jwt/jwt.go | 166 ++++++++ auth/internal/jwt/jwt_test.go | 79 ++++ auth/threelegged.go | 335 ++++++++++++++++ auth/threelegged_test.go | 475 +++++++++++++++++++++++ go.work | 1 + release-please-config-individual.json | 3 + 13 files changed, 2024 insertions(+), 1 deletion(-) create mode 100644 auth/README.md create mode 100644 auth/auth.go create mode 100644 auth/auth_test.go create mode 100644 auth/example_test.go create mode 100644 auth/go.mod create mode 100644 auth/internal/internal.go create mode 100644 auth/internal/jwt/jwt.go create mode 100644 auth/internal/jwt/jwt_test.go create mode 100644 auth/threelegged.go create mode 100644 auth/threelegged_test.go diff --git a/.release-please-manifest-individual.json b/.release-please-manifest-individual.json index 429989ced9a1..bea575724b99 100644 --- a/.release-please-manifest-individual.json +++ b/.release-please-manifest-individual.json @@ -1,4 +1,5 @@ { + "auth": "0.0.0", "bigquery": "1.54.0", "bigtable": "1.19.0", "datastore": "1.13.0", @@ -10,4 +11,4 @@ "pubsublite": "1.8.1", "spanner": "1.48.0", "storage": "1.32.0" -} +} \ No newline at end of file diff --git a/auth/README.md b/auth/README.md new file mode 100644 index 000000000000..44735e0d0526 --- /dev/null +++ b/auth/README.md @@ -0,0 +1,4 @@ +# auth + +This module is currently EXPERIMENTAL and under active development. It is not +yet indented to be used. diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 000000000000..ff0047eb7a56 --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,341 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "cloud.google.com/go/auth/internal" + "cloud.google.com/go/auth/internal/jwt" +) + +const ( + // Parameter keys for AuthCodeURL method to support PKCE. + codeChallengeKey = "code_challenge" + codeChallengeMethodKey = "code_challenge_method" + + // Parameter key for Exchange method to support PKCE. + codeVerifierKey = "code_verifier" + + defaultExpiryDelta = 10 * time.Second +) + +var ( + defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer" + defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType} + + // for testing + timeNow = time.Now +) + +// TokenProvider specifies an interface for anything that can return a token. +type TokenProvider interface { + // Token returns a Token or an error. + // The Token returned must be safe to use + // concurrently. + // The returned Token must not be modified. + // The context provided must be sent along to any requests that are made in + // the implementing code. + Token(context.Context) (*Token, error) +} + +// Token holds the credential token used to authorized requests. All fields are +// considered read-only. +type Token struct { + // Value is the token used to authorize requests. It is usually an access + // token but may be other types of tokens such as ID tokens in some flows. + Value string + // Type is the type of token Value is. If uninitialized, it should be + // assumed to be a "Bearer" token. + Type string + // Expiry is the time the token is set to expire. + Expiry time.Time + // Metadata may include, but is not limited to, the body of the token + // response returned by the server. + Metadata map[string]interface{} // TODO(codyoss): maybe make a method to flatten metadata to avoid []string for url.Values +} + +// IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not +// expired. A token is considered expired if [Token.Expiry] has passed or will +// pass in the next 10 seconds. +func (t *Token) IsValid() bool { + return t.isValidWithEarlyExpiry(defaultExpiryDelta) +} + +func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool { + if t == nil || t.Value == "" { + return false + } + if t.Expiry.IsZero() { + return true + } + return !t.Expiry.Round(0).Add(-earlyExpiry).Before(timeNow()) +} + +// CachedTokenProviderOptions provided options for configuring a +// CachedTokenProvider. +type CachedTokenProviderOptions struct { + // DisableAutoRefresh makes the TokenProvider always return the same token, + // even if it is expired. + DisableAutoRefresh bool + // ExpireEarly configures the amount of time before a token expires, that it + // should be refreshed. + ExpireEarly time.Duration +} + +func (ctpo *CachedTokenProviderOptions) autoRefresh() bool { + if ctpo == nil { + return true + } + return !ctpo.DisableAutoRefresh +} + +func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { + if ctpo == nil { + return defaultExpiryDelta + } + return ctpo.ExpireEarly +} + +// May need to also pass a token for the user-auth flow? +func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider { + if ctp, ok := tp.(*cachedTokenProvider); ok { + return ctp + } + return &cachedTokenProvider{ + tp: tp, + autoRefresh: opts.autoRefresh(), + expireEarly: opts.expireEarly(), + } +} + +type cachedTokenProvider struct { + tp TokenProvider + autoRefresh bool + expireEarly time.Duration + + mu sync.Mutex + cachedToken *Token +} + +func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.cachedToken.IsValid() || !c.autoRefresh { + return c.cachedToken, nil + } + t, err := c.tp.Token(ctx) + if err != nil { + return nil, err + } + c.cachedToken = t + return t, nil +} + +type Error struct { + // Response is the HTTP response associated with error. The body will always + // be already closed and consumed. + Response *http.Response + // Body is the HTTP response body. + Body []byte + // Err is the underlying wrapped error. + Err error + + // code returned in the token response + code string + // description returned in the token response + description string + // uri returned in the token response + uri string +} + +func (r *Error) Error() string { + if r.code != "" { + s := fmt.Sprintf("auth: %q", r.code) + if r.description != "" { + s += fmt.Sprintf(" %q", r.description) + } + if r.uri != "" { + s += fmt.Sprintf(" %q", r.uri) + } + return s + } + return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", r.Response.StatusCode, r.Body) +} + +func (e *Error) Temporary() bool { + if e.Response == nil { + return false + } + sc := e.Response.StatusCode + return sc == 500 || sc == 503 || sc == 408 || sc == 429 +} + +func (e *Error) Unwrap() error { + return e.Err +} + +// AuthStyle describes how the token endpoint wants receive the ClientID and +// ClientSecret. +type AuthStyle int + +const ( + // AuthStyleUnknown means the value has not been initiated. Sending this in + // a request will cause the token exchange to fail. + AuthStyleUnknown AuthStyle = 0 + // AuthStyleInParams sends client info in the body of a POST request. + AuthStyleInParams AuthStyle = 1 + // AuthStyleInHeader sends client info using Basic Authorization header. + AuthStyleInHeader AuthStyle = 2 +) + +// ConfigJWT2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow. +type ConfigJWT2LO struct { + // Email is the OAuth2 client ID. This value is set as the "iss" in the + // JWT. + Email string + // PrivateKey contains the contents of an RSA private key or the + // contents of a PEM file that contains a private key. It is used to sign + // the JWT created. + PrivateKey []byte + // PrivateKeyID is the ID of the key used to sign the JWT. It is used as the + // "kid" in the JWT header. + PrivateKeyID string + // Subject is the used for to impersonate a user. It is used as the "sub" in + // the JWT.m Optional. + Subject string + // Scopes specifies requested permissions for the token. Optional. + Scopes []string + // TokenURL is th URL the JWT is sent to. + TokenURL string + // Expires specifies the lifetime of the token. + Expires time.Duration + // Audience specifies the "aud" in the JWT. Optional. + Audience string + // PrivateClaims allows specifying any custom claims for the JWT. Optional. + PrivateClaims map[string]interface{} + + // Client is the client to be used to make the underlying token requests. + // Optional. + Client *http.Client + // UseIDToken requests that the token returned be an ID token if one is + // returned from the server. Optional. + UseIDToken bool +} + +func (c *ConfigJWT2LO) client() *http.Client { + if c.Client != nil { + return c.Client + } + return internal.CloneDefaultClient() +} + +func (c *ConfigJWT2LO) TokenProvider() TokenProvider { + return tokenProvider2LO{c: c, Client: c.client()} +} + +type tokenProvider2LO struct { + c *ConfigJWT2LO + Client *http.Client +} + +func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { + pk, err := internal.ParseKey(tp.c.PrivateKey) + if err != nil { + return nil, err + } + claimSet := &jwt.Claims{ + Iss: tp.c.Email, + Scope: strings.Join(tp.c.Scopes, " "), + Aud: tp.c.TokenURL, + AdditionalClaims: tp.c.PrivateClaims, + } + if subject := tp.c.Subject; subject != "" { + claimSet.Sub = subject + } + if t := tp.c.Expires; t > 0 { + claimSet.Exp = time.Now().Add(t).Unix() + } + if aud := tp.c.Audience; aud != "" { + claimSet.Aud = aud + } + h := *defaultHeader + h.KeyID = tp.c.PrivateKeyID + payload, err := jwt.EncodeJWS(&h, claimSet, pk) + if err != nil { + return nil, err + } + v := url.Values{} + v.Set("grant_type", defaultGrantType) + v.Set("assertion", payload) + resp, err := tp.Client.PostForm(tp.c.TokenURL, v) + if err != nil { + return nil, fmt.Errorf("auth: cannot fetch token: %w", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("auth: cannot fetch token: %w", err) + } + if c := resp.StatusCode; c < 200 || c > 299 { + return nil, &Error{ + Response: resp, + Body: body, + } + } + // tokenRes is the JSON response body. + var tokenRes struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + IDToken string `json:"id_token"` + ExpiresIn int64 `json:"expires_in"` + } + if err := json.Unmarshal(body, &tokenRes); err != nil { + return nil, fmt.Errorf("auth: cannot fetch token: %w", err) + } + token := &Token{ + Value: tokenRes.AccessToken, + Type: tokenRes.TokenType, + } + token.Metadata = make(map[string]interface{}) + json.Unmarshal(body, &token.Metadata) // no error checks for optional fields + + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + if v := tokenRes.IDToken; v != "" { + // decode returned id token to get expiry + claimSet, err := jwt.DecodeJWS(v) + if err != nil { + return nil, fmt.Errorf("auth: error decoding JWT token: %w", err) + } + token.Expiry = time.Unix(claimSet.Exp, 0) + } + if tp.c.UseIDToken { + if tokenRes.IDToken == "" { + return nil, fmt.Errorf("auth: response doesn't have JWT token") + } + token.Value = tokenRes.IDToken + } + return token, nil +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 000000000000..4be4cbb45db0 --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,445 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" + + "cloud.google.com/go/auth/internal/jwt" +) + +var fakePrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE +DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY +fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK +1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr +k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9 +/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt +3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn +2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3 +nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK +6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf +5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e +DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1 +M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g +z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y +1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK +J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U +f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx +QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA +cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr +Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw +5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg +KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84 +OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd +mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ +5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg== +-----END RSA PRIVATE KEY-----`) + +func TestError_Temporary(t *testing.T) { + tests := []struct { + name string + code int + want bool + }{ + { + name: "temporary with 500", + code: 500, + want: true, + }, + { + name: "temporary with 503", + code: 503, + want: true, + }, + { + name: "temporary with 408", + code: 408, + want: true, + }, + { + name: "temporary with 429", + code: 429, + want: true, + }, + { + name: "temporary with 418", + code: 418, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ae := &Error{ + Response: &http.Response{ + StatusCode: tt.code, + }, + } + if got := ae.Temporary(); got != tt.want { + t.Errorf("Temporary() = %v; want %v", got, tt.want) + } + }) + } +} + +func TestToken_isValidWithEarlyExpiry(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { return now } + defer func() { timeNow = time.Now }() + + cases := []struct { + name string + tok *Token + expiry time.Duration + want bool + }{ + {name: "12 seconds", tok: &Token{Expiry: now.Add(12 * time.Second)}, expiry: defaultExpiryDelta, want: true}, + {name: "10 seconds", tok: &Token{Expiry: now.Add(defaultExpiryDelta)}, expiry: defaultExpiryDelta, want: true}, + {name: "10 seconds-1ns", tok: &Token{Expiry: now.Add(defaultExpiryDelta - 1*time.Nanosecond)}, expiry: defaultExpiryDelta, want: false}, + {name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: defaultExpiryDelta, want: false}, + {name: "12 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(12 * time.Second)}, expiry: time.Second * 5, want: true}, + {name: "5 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second * 5)}, expiry: time.Second * 5, want: true}, + {name: "5 seconds-1ns, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second*5 - 1*time.Nanosecond)}, expiry: time.Second * 5, want: false}, + {name: "-1 hour, custom expiryDelta", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: time.Second * 5, want: false}, + } + for _, tc := range cases { + tc.tok.Value = "tok" + if got, want := tc.tok.isValidWithEarlyExpiry(tc.expiry), tc.want; got != want { + t.Errorf("expired (%q) = %v; want %v", tc.name, got, want) + } + } +} + +func TestError_Error(t *testing.T) { + + tests := []struct { + name string + + Response *http.Response + Body []byte + Err error + code string + description string + uri string + + want string + }{ + { + name: "basic", + Response: &http.Response{ + StatusCode: 418, + }, + Body: []byte("I'm a teapot"), + want: "auth: cannot fetch token: 418\nResponse: I'm a teapot", + }, + { + name: "from query", + code: "418", + description: "I'm a teapot", + uri: "somewhere", + want: "auth: \"418\" \"I'm a teapot\" \"somewhere\"", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Error{ + Response: tt.Response, + Body: tt.Body, + Err: tt.Err, + code: tt.code, + description: tt.description, + uri: tt.uri, + } + if got := r.Error(); got != tt.want { + t.Errorf("Error.Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConfigJWT2LO_JSONResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + conf := &ConfigJWT2LO{ + Email: "aaa@xxx.com", + PrivateKey: fakePrivateKey, + TokenURL: ts.URL, + } + tok, err := conf.TokenProvider().Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if !tok.IsValid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.Type, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + if got := tok.Expiry.IsZero(); got { + t.Errorf("token expiry = %v, want none", got) + } + scope := tok.Metadata["scope"].(string) + if got, want := scope, "user"; got != want { + t.Errorf("scope = %q; want %q", got, want) + } +} + +func TestConfigJWT2LO_BadResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + + conf := &ConfigJWT2LO{ + Email: "aaa@xxx.com", + PrivateKey: fakePrivateKey, + TokenURL: ts.URL, + } + tok, err := conf.TokenProvider().Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if tok == nil { + t.Fatalf("got nil token; want token") + } + if tok.IsValid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.Value, ""; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.Type, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + scope := tok.Metadata["scope"].(string) + if got, want := scope, "user"; got != want { + t.Errorf("token scope = %q; want %q", got, want) + } +} + +func TestConfigJWT2LO_BadResponseType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf := &ConfigJWT2LO{ + Email: "aaa@xxx.com", + PrivateKey: fakePrivateKey, + TokenURL: ts.URL, + } + tok, err := conf.TokenProvider().Token(context.Background()) + if err == nil { + t.Error("got a token; expected error") + if got, want := tok.Value, ""; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + } +} + +func TestConfigJWT2LO_Assertion(t *testing.T) { + var assertion string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertion = r.Form.Get("assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + conf := &ConfigJWT2LO{ + Email: "aaa@xxx.com", + PrivateKey: fakePrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + } + + _, err := conf.TokenProvider().Token(context.Background()) + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotjson, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + t.Fatalf("invalid token header; err = %v", err) + } + + got := jwt.Header{} + if err := json.Unmarshal(gotjson, &got); err != nil { + t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err) + } + + want := jwt.Header{ + Algorithm: "RS256", + Type: "JWT", + KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + } + if got != want { + t.Errorf("access token header = %q; want %q", got, want) + } +} + +func TestConfigJWT2LO_AssertionPayload(t *testing.T) { + var assertion string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + assertion = r.Form.Get("assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "user", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + for _, conf := range []*ConfigJWT2LO{ + { + Email: "aaa1@xxx.com", + PrivateKey: fakePrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + }, + { + Email: "aaa2@xxx.com", + PrivateKey: fakePrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + Audience: "https://example.com", + }, + { + Email: "aaa2@xxx.com", + PrivateKey: fakePrivateKey, + PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + TokenURL: ts.URL, + PrivateClaims: map[string]interface{}{ + "private0": "claim0", + "private1": "claim1", + }, + }, + } { + t.Run(conf.Email, func(t *testing.T) { + _, err := conf.TokenProvider().Token(context.Background()) + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotjson, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("invalid token payload; err = %v", err) + } + + claimSet := jwt.Claims{} + if err := json.Unmarshal(gotjson, &claimSet); err != nil { + t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err) + } + + if got, want := claimSet.Iss, conf.Email; got != want { + t.Errorf("payload email = %q; want %q", got, want) + } + if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want { + t.Errorf("payload scope = %q; want %q", got, want) + } + aud := conf.TokenURL + if conf.Audience != "" { + aud = conf.Audience + } + if got, want := claimSet.Aud, aud; got != want { + t.Errorf("payload audience = %q; want %q", got, want) + } + if got, want := claimSet.Sub, conf.Subject; got != want { + t.Errorf("payload subject = %q; want %q", got, want) + } + if len(conf.PrivateClaims) > 0 { + var got interface{} + if err := json.Unmarshal(gotjson, &got); err != nil { + t.Errorf("failed to parse payload; err = %q", err) + } + m := got.(map[string]interface{}) + for v, k := range conf.PrivateClaims { + if !reflect.DeepEqual(m[v], k) { + t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k) + } + } + } + }) + } +} + +func TestConfigJWT2LO_TokenError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "invalid_grant"}`)) + })) + defer ts.Close() + + conf := &ConfigJWT2LO{ + Email: "aaa@xxx.com", + PrivateKey: fakePrivateKey, + TokenURL: ts.URL, + } + + _, err := conf.TokenProvider().Token(context.Background()) + if err == nil { + t.Fatalf("got no error, expected one") + } + _, ok := err.(*Error) + if !ok { + t.Fatalf("got %T error, expected *Error", err) + } + expected := fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", "400", `{"error": "invalid_grant"}`) + if errStr := err.Error(); errStr != expected { + t.Fatalf("got %#v, expected %#v", errStr, expected) + } +} diff --git a/auth/example_test.go b/auth/example_test.go new file mode 100644 index 000000000000..bf6d733ae3c2 --- /dev/null +++ b/auth/example_test.go @@ -0,0 +1,58 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth_test + +import ( + "cloud.google.com/go/auth" +) + +func ExampleConfigJWT2LO() { + // Your credentials should be obtained from the Google + // Developer Console (https://console.developers.google.com). + conf := &auth.ConfigJWT2LO{ + Email: "xxx@developer.gserviceaccount.com", + // The contents of your RSA private key or your PEM file + // that contains a private key. + // If you have a p12 file instead, you + // can use `openssl` to export the private key into a pem file. + // + // $ openssl pkcs12 -in key.p12 -passin pass:notasecret -out key.pem -nodes + // + // The field only supports PEM containers with no passphrase. + // The openssl command will convert p12 keys to passphrase-less PEM containers. + PrivateKey: []byte("-----BEGIN RSA PRIVATE KEY-----..."), + Scopes: []string{ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/blogger", + }, + TokenURL: "https://oauth2.googleapis.com/token", + // If you would like to impersonate a user, you can + // create a transport with a subject. The following GET + // request will be made on the behalf of user@example.com. + // Optional. + Subject: "user@example.com", + } + + tp := conf.TokenProvider() + // TODO(codyoss): Fixup once more code is merged + // client, err := httptransport.NewClient(&httptransport.Options{ + // TokenProvider: tp, + // }) + // if err != nil { + // log.Fatal(err) + // } + // client.Get("...") + _ = tp +} diff --git a/auth/go.mod b/auth/go.mod new file mode 100644 index 000000000000..33b4c07d9867 --- /dev/null +++ b/auth/go.mod @@ -0,0 +1,3 @@ +module cloud.google.com/go/auth + +go 1.19 diff --git a/auth/internal/internal.go b/auth/internal/internal.go new file mode 100644 index 000000000000..c0dfbb80813f --- /dev/null +++ b/auth/internal/internal.go @@ -0,0 +1,112 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "net/http" + "os" + "time" +) + +const ( + TokenTypeBearer = "Bearer" + + quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" + projectEnvVar = "GOOGLE_CLOUD_PROJECT" +) + +// CloneDefaultClient returns a [http.Client] with some good defaults. +func CloneDefaultClient() *http.Client { + return &http.Client{ + Transport: http.DefaultTransport.(*http.Transport).Clone(), + Timeout: 30 * time.Second, + } +} + +// ParseKey converts the binary contents of a private key file +// to an *rsa.PrivateKey. It detects whether the private key is in a +// PEM container or not. If so, it extracts the the private key +// from PEM container before conversion. It only supports PEM +// containers with no passphrase. +func ParseKey(key []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(key) + if block != nil { + key = block.Bytes + } + parsedKey, err := x509.ParsePKCS8PrivateKey(key) + if err != nil { + parsedKey, err = x509.ParsePKCS1PrivateKey(key) + if err != nil { + return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8: %w", err) + } + } + parsed, ok := parsedKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("private key is invalid") + } + return parsed, nil +} + +// GetQuotaProject retrieves quota project with precedence being: override, +// environment variable, creds json file. +func GetQuotaProject(b []byte, override string) string { + if override != "" { + return override + } + if env := os.Getenv(quotaProjectEnvVar); env != "" { + return env + } + if b == nil { + return "" + } + var v struct { + QuotaProject string `json:"quota_project_id"` + } + if err := json.Unmarshal(b, &v); err != nil { + return "" + } + return v.QuotaProject +} + +// GetProjectID retrieves project with precedence being: override, +// environment variable, creds json file. +func GetProjectID(b []byte, override string) string { + if override != "" { + return override + } + if env := os.Getenv(projectEnvVar); env != "" { + return env + } + if b == nil { + return "" + } + var v struct { + ProjectID string `json:"project_id"` // standard service account key + Project string `json:"project"` // gdch key + } + if err := json.Unmarshal(b, &v); err != nil { + return "" + } + if v.ProjectID != "" { + return v.ProjectID + } + return v.Project +} diff --git a/auth/internal/jwt/jwt.go b/auth/internal/jwt/jwt.go new file mode 100644 index 000000000000..a07db35df924 --- /dev/null +++ b/auth/internal/jwt/jwt.go @@ -0,0 +1,166 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jwt + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" +) + +const ( + HeaderAlgRSA256 = "RS256" + HeaderAlgES256 = "ES256" + HeaderType = "JWT" +) + +type Header struct { + Algorithm string `json:"alg"` + Type string `json:"typ"` + KeyID string `json:"kid"` +} + +func (h *Header) encode() (string, error) { + b, err := json.Marshal(h) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +type Claims struct { + // Iss is the issuer JWT claim. + Iss string `json:"iss"` + // Scope is the scope JWT claim. + Scope string `json:"scope,omitempty"` + // Exp is the expiry JWT claim. + Exp int64 `json:"exp"` + // Iat is the subject issued at claim. + Iat int64 `json:"iat"` + // Aud is the audience JWT claim. Optional. + Aud string `json:"aud"` + // Sub is the subject JWT claim. Optional. + Sub string `json:"sub,omitempty"` + // AdditionalClaims contains any additional non-standard JWT claims. Optional. + AdditionalClaims map[string]interface{} `json:"-"` +} + +func (c *Claims) encode() (string, error) { + // Compensate for skew + now := time.Now().Add(-10 * time.Second) + if c.Iat == 0 { + c.Iat = now.Unix() + } + if c.Exp == 0 { + c.Exp = now.Add(time.Hour).Unix() + } + if c.Exp < c.Iat { + return "", fmt.Errorf("jwt: invalid Exp = %d; must be later than Iat = %d", c.Exp, c.Iat) + } + + b, err := json.Marshal(c) + if err != nil { + return "", err + } + + if len(c.AdditionalClaims) == 0 { + return base64.RawURLEncoding.EncodeToString(b), nil + } + + // Marshal private claim set and then append it to b. + prv, err := json.Marshal(c.AdditionalClaims) + if err != nil { + return "", fmt.Errorf("invalid map of additional claims %v", c.AdditionalClaims) + } + + // Concatenate public and private claim JSON objects. + if !bytes.HasSuffix(b, []byte{'}'}) { + return "", fmt.Errorf("invalid JSON %s", b) + } + if !bytes.HasPrefix(prv, []byte{'{'}) { + return "", fmt.Errorf("invalid JSON %s", prv) + } + b[len(b)-1] = ',' // Replace closing curly brace with a comma. + b = append(b, prv[1:]...) // Append private claims. + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// EncodeJWS encodes the data using the provided key as a JSON web signature. +func EncodeJWS(header *Header, c *Claims, key *rsa.PrivateKey) (string, error) { + head, err := header.encode() + if err != nil { + return "", err + } + claims, err := c.encode() + if err != nil { + return "", err + } + ss := fmt.Sprintf("%s.%s", head, claims) + h := sha256.New() + h.Write([]byte(ss)) + sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil)) + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil +} + +// DecodeJWS decodes a claim set from a JWS payload. +func DecodeJWS(payload string) (*Claims, error) { + // decode returned id token to get expiry + s := strings.Split(payload, ".") + if len(s) < 2 { + return nil, errors.New("invalid token received") + } + decoded, err := base64.RawURLEncoding.DecodeString(s[1]) + if err != nil { + return nil, err + } + c := &Claims{} + if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c); err != nil { + return nil, err + } + if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c.AdditionalClaims); err != nil { + return nil, err + } + return c, err +} + +// VerifyJWS tests whether the provided JWT token's signature was produced by +// the private key associated with the provided public key. +func VerifyJWS(token string, key *rsa.PublicKey) error { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return errors.New("jwt: invalid token received, token must have 3 parts") + } + + signedContent := parts[0] + "." + parts[1] + signatureString, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return err + } + + h := sha256.New() + h.Write([]byte(signedContent)) + return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) +} diff --git a/auth/internal/jwt/jwt_test.go b/auth/internal/jwt/jwt_test.go new file mode 100644 index 000000000000..0993695b5716 --- /dev/null +++ b/auth/internal/jwt/jwt_test.go @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jwt + +import ( + "crypto/rand" + "crypto/rsa" + "testing" +) + +func TestSignAndVerifyDecode(t *testing.T) { + header := &Header{ + Algorithm: "RS256", + Type: "JWT", + } + payload := &Claims{ + Iss: "http://google.com/", + Aud: "", + Exp: 3610, + Iat: 10, + AdditionalClaims: map[string]interface{}{ + "foo": "bar", + }, + } + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + token, err := EncodeJWS(header, payload, privateKey) + if err != nil { + t.Fatal(err) + } + + if err := VerifyJWS(token, &privateKey.PublicKey); err != nil { + t.Fatal(err) + } + + claims, err := DecodeJWS(token) + if err != nil { + t.Fatal(err) + } + + if claims.Iss != payload.Iss { + t.Errorf("got %q, want %q", claims.Iss, payload.Iss) + } + if claims.Aud != payload.Aud { + t.Errorf("got %q, want %q", claims.Aud, payload.Aud) + } + if claims.Exp != payload.Exp { + t.Errorf("got %d, want %d", claims.Exp, payload.Exp) + } + if claims.Iat != payload.Iat { + t.Errorf("got %d, want %d", claims.Iat, payload.Iat) + } + if claims.AdditionalClaims["foo"] != payload.AdditionalClaims["foo"] { + t.Errorf("got %q, want %q", claims.AdditionalClaims["foo"], payload.AdditionalClaims["foo"]) + } +} + +func TestVerifyFailsOnMalformedClaim(t *testing.T) { + err := VerifyJWS("abc.def", nil) + if err == nil { + t.Error("got no errors; want improperly formed JWT not to be verified") + } +} diff --git a/auth/threelegged.go b/auth/threelegged.go new file mode 100644 index 000000000000..f29776d3a6e8 --- /dev/null +++ b/auth/threelegged.go @@ -0,0 +1,335 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "cloud.google.com/go/auth/internal" +) + +// AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for +// OAuth consent at the specified auth code URL and returns an auth code and +// state upon approval. +type AuthorizationHandler func(authCodeURL string) (code string, state string, err error) + +// Config3LO is the configuration settings for doing a 3-legged OAuth2 flow. +type Config3LO struct { + // ClientID is the application's ID. + ClientID string + // ClientSecret is the application's secret. + ClientSecret string + // AuthURL is the URL for authenticating. + AuthURL string + // TokenURL is the URL for retrieving a token. + TokenURL string + // RedirectURL is the URL to redirect users to. + RedirectURL string + // Scopes specifies requested permissions for the Token. Optional. + Scopes []string + + // URLParams are the set of values to apply to the token exchange. Optional. + URLParams url.Values + // Client is the client to be used to make the underlying token requests. + // Optional. + Client *http.Client + // AuthStyle is used to describe how to client info in the token request. + AuthStyle AuthStyle + // EarlyTokenExpiry is the time before the token expires that it should be + // refreshed. If not set the default value is 10 seconds. Optional. + EarlyTokenExpiry time.Duration + + pkceConf *PKCEConfig +} + +// PKCEParams holds parameters to support PKCE. +type PKCEConfig struct { + // Challenge is the un-padded, base64-url-encoded string of the encrypted code verifier. + Challenge string // The un-padded, base64-url-encoded string of the encrypted code verifier. + // ChallengeMethod is the encryption method (ex. S256). + ChallengeMethod string + // Verifier is the original, non-encrypted secret. + Verifier string // The original, non-encrypted secret. +} + +type tokenJSON struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + // error fields + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` +} + +func (e *tokenJSON) expiry() (t time.Time) { + if v := e.ExpiresIn; v != 0 { + return time.Now().Add(time.Duration(v) * time.Second) + } + return +} + +func (c *Config3LO) client() *http.Client { + if c.Client != nil { + return c.Client + } + return internal.CloneDefaultClient() +} + +// authCodeURL returns a URL that points to a OAuth2 consent page. +func (c *Config3LO) authCodeURL(state string, values url.Values) string { + var buf bytes.Buffer + buf.WriteString(c.AuthURL) + v := url.Values{ + "response_type": {"code"}, + "client_id": {c.ClientID}, + } + if c.RedirectURL != "" { + v.Set("redirect_uri", c.RedirectURL) + } + if len(c.Scopes) > 0 { + v.Set("scope", strings.Join(c.Scopes, " ")) + } + if state != "" { + v.Set("state", state) + } + if c.pkceConf != nil && c.pkceConf.Challenge != "" && c.pkceConf.ChallengeMethod != "" { + v.Set(codeChallengeKey, c.pkceConf.Challenge) + } + if c.pkceConf != nil && c.pkceConf.ChallengeMethod != "" { + v.Set(codeChallengeMethodKey, c.pkceConf.ChallengeMethod) + } + for k := range values { + v.Set(k, v.Get(k)) + } + if strings.Contains(c.AuthURL, "?") { + buf.WriteByte('&') + } else { + buf.WriteByte('?') + } + buf.WriteString(v.Encode()) + return buf.String() +} + +// TokenProvider returns a TokenProvider based the the 3-legged OAuth2 +// configuration. The TokenProvider is caches and auto-refreshes tokens by +// default. +func (c *Config3LO) TokenProvider(refreshToken string) TokenProvider { + return NewCachedTokenProvider(&tokenProvider3LO{config: c, refreshToken: refreshToken, client: c.client()}, &CachedTokenProviderOptions{ + ExpireEarly: c.EarlyTokenExpiry, + }) +} + +// AuthHandlerOptions provides a set of options to specify for doing a +// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. +type AuthHandlerOptions struct { + // AuthorizationHandler specifies the handler used to for the authorization + // part of the flow. + AuthorizationHandler AuthorizationHandler + // State is used verify that the "state" is identical in the request and + // response before exchanging the auth code for OAuth2 token. + State string + // PKCEConfig allows setting configurations for PKCE. Optional. + PKCEConfig *PKCEConfig +} + +func (c *Config3LO) TokenProviderWithAuthHandler(opts AuthHandlerOptions) TokenProvider { + c.pkceConf = opts.PKCEConfig + return NewCachedTokenProvider(&tokenProviderWithHandler{c: c, handler: opts.AuthorizationHandler, state: opts.State}, &CachedTokenProviderOptions{ + ExpireEarly: c.EarlyTokenExpiry, + }) +} + +// exchange handles the final exchange portion of the 3lo flow. Returns a Token, +// refreshToken, and error. +func (c *Config3LO) exchange(ctx context.Context, code string) (*Token, string, error) { + // Build request + v := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + } + if c.RedirectURL != "" { + v.Set("redirect_uri", c.RedirectURL) + } + if c.pkceConf != nil && c.pkceConf.Verifier != "" { + v.Set(codeVerifierKey, c.pkceConf.Verifier) + } + for k := range c.URLParams { + v.Set(k, c.URLParams.Get(k)) + } + return fetchToken(ctx, c, v) +} + +// This struct is not safe for concurrent access alone, but the way it is used +// in this package by wrapping it with a cachedTokenProvider makes it so. +type tokenProvider3LO struct { + config *Config3LO + client *http.Client + refreshToken string +} + +func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) { + if tp.refreshToken == "" { + return nil, errors.New("auth: token expired and refresh token is not set") + } + v := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {tp.refreshToken}, + } + for k := range tp.config.URLParams { + v.Set(k, tp.config.URLParams.Get(k)) + } + + tk, rt, err := fetchToken(ctx, tp.config, v) + if err != nil { + return nil, err + } + if tp.refreshToken != rt && rt != "" { + tp.refreshToken = rt + } + return tk, err +} + +type tokenProviderWithHandler struct { + c *Config3LO + handler AuthorizationHandler + state string +} + +func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) { + url := tp.c.authCodeURL(tp.state, nil) + code, state, err := tp.handler(url) + if err != nil { + return nil, err + } + if state != tp.state { + return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow") + } + tok, _, err := tp.c.exchange(ctx, code) + return tok, err +} + +// fetchToken returns a Token, refresh token, and/or an error. +func fetchToken(ctx context.Context, c *Config3LO, v url.Values) (*Token, string, error) { + var refreshToken string + if c.AuthStyle == AuthStyleUnknown { + return nil, refreshToken, fmt.Errorf("auth: missing required field AuthStyle") + } + if c.AuthStyle == AuthStyleInParams { + if c.ClientID != "" { + v.Set("client_id", c.ClientID) + } + if c.ClientSecret != "" { + v.Set("client_secret", c.ClientSecret) + } + } + req, err := http.NewRequest("POST", c.TokenURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, refreshToken, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if c.AuthStyle == AuthStyleInHeader { + req.SetBasicAuth(url.QueryEscape(c.ClientID), url.QueryEscape(c.ClientSecret)) + } + + // Make request + r, err := c.client().Do(req.WithContext(ctx)) + if err != nil { + return nil, refreshToken, err + } + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + r.Body.Close() + if err != nil { + return nil, refreshToken, fmt.Errorf("auth: cannot fetch token: %w", err) + } + + failureStatus := r.StatusCode < 200 || r.StatusCode > 299 + tokError := &Error{ + Response: r, + Body: body, + } + + var token *Token + content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch content { + case "application/x-www-form-urlencoded", "text/plain": + // some endpoints return a query string + vals, err := url.ParseQuery(string(body)) + if err != nil { + if failureStatus { + return nil, refreshToken, tokError + } + return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err) + } + tokError.code = vals.Get("error") + tokError.description = vals.Get("error_description") + tokError.uri = vals.Get("error_uri") + token = &Token{ + Value: vals.Get("access_token"), + Type: vals.Get("token_type"), + Metadata: make(map[string]interface{}, len(vals)), + } + for k, v := range vals { + token.Metadata[k] = v + } + refreshToken = vals.Get("refresh_token") + e := vals.Get("expires_in") + expires, _ := strconv.Atoi(e) + if expires != 0 { + token.Expiry = time.Now().Add(time.Duration(expires) * time.Second) + } + default: + var tj tokenJSON + if err = json.Unmarshal(body, &tj); err != nil { + if failureStatus { + return nil, refreshToken, tokError + } + return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err) + } + tokError.code = tj.ErrorCode + tokError.description = tj.ErrorDescription + tokError.uri = tj.ErrorURI + token = &Token{ + Value: tj.AccessToken, + Type: tj.TokenType, + Expiry: tj.expiry(), + Metadata: make(map[string]interface{}), + } + json.Unmarshal(body, &token.Metadata) // optional field, skip err check + refreshToken = tj.RefreshToken + } + // according to spec, servers should respond status 400 in error case + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + // but some unorthodox servers respond 200 in error case + if failureStatus || tokError.code != "" { + return nil, refreshToken, tokError + } + if token.Value == "" { + return nil, refreshToken, errors.New("auth: server response missing access_token") + } + return token, refreshToken, nil +} diff --git a/auth/threelegged_test.go b/auth/threelegged_test.go new file mode 100644 index 000000000000..018d9d3f09eb --- /dev/null +++ b/auth/threelegged_test.go @@ -0,0 +1,475 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +const day = 24 * time.Hour + +type mockTransport struct { + rt func(req *http.Request) (resp *http.Response, err error) +} + +func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return t.rt(req) +} + +func newConf(url string) *Config3LO { + return &Config3LO{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: []string{"scope1", "scope2"}, + AuthURL: url + "/auth", + TokenURL: url + "/token", + AuthStyle: AuthStyleInHeader, + } +} + +func TestConfig3LO_URLUnsafe(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want { + t.Errorf("Authorization header = %q; want %q", got, want) + } + + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) + })) + defer ts.Close() + conf := newConf(ts.URL) + conf.ClientID = "CLIENT_ID??" + conf.ClientSecret = "CLIENT_SECRET??" + _, _, err := conf.exchange(context.Background(), "exchange-code") + if err != nil { + t.Error(err) + } +} + +func TestConfig3LO_StandardExchange(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected exchange request URL %q", r.URL) + } + headerAuth := r.Header.Get("Authorization") + if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want { + t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header %q", headerContentType) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed reading request body: %s.", err) + } + if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { + t.Errorf("Unexpected exchange payload; got %q", body) + } + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) + })) + defer ts.Close() + conf := newConf(ts.URL) + tok, _, err := conf.exchange(context.Background(), "exchange-code") + if err != nil { + t.Error(err) + } + if !tok.IsValid() { + t.Fatalf("Token invalid. Got: %#v", tok) + } + if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Unexpected access token, %#v.", tok.Value) + } + if tok.Type != "bearer" { + t.Errorf("Unexpected token type, %#v.", tok.Type) + } + scope := tok.Metadata["scope"].([]string) + if scope[0] != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } +} + +func TestConfig3LO_ExchangeCustomParams(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) + } + headerAuth := r.Header.Get("Authorization") + if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { + t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed reading request body: %s.", err) + } + if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { + t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + } + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) + })) + defer ts.Close() + conf := newConf(ts.URL) + conf.URLParams = url.Values{} + conf.URLParams.Set("foo", "bar") + + tok, _, err := conf.exchange(context.Background(), "exchange-code") + if err != nil { + t.Error(err) + } + if !tok.IsValid() { + t.Fatalf("Token invalid. Got: %#v", tok) + } + if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Unexpected access token, %#v.", tok.Value) + } + if tok.Type != "bearer" { + t.Errorf("Unexpected token type, %#v.", tok.Type) + } + scope := tok.Metadata["scope"].([]string) + if scope[0] != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } +} + +func TestConfig3LO_ExchangeJSONResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) + } + headerAuth := r.Header.Get("Authorization") + if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { + t.Errorf("Unexpected authorization header, %v is found.", headerAuth) + } + headerContentType := r.Header.Get("Content-Type") + if headerContentType != "application/x-www-form-urlencoded" { + t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed reading request body: %s.", err) + } + if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { + t.Errorf("Unexpected exchange payload, %v is found.", string(body)) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) + })) + defer ts.Close() + conf := newConf(ts.URL) + tok, _, err := conf.exchange(context.Background(), "exchange-code") + if err != nil { + t.Error(err) + } + if !tok.IsValid() { + t.Fatalf("Token invalid. Got: %#v", tok) + } + if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { + t.Errorf("Unexpected access token, %#v.", tok.Value) + } + if tok.Type != "bearer" { + t.Errorf("Unexpected token type, %#v.", tok.Type) + } + scope := tok.Metadata["scope"].(string) + if scope != "user" { + t.Errorf("Unexpected value for scope: %v", scope) + } + expiresIn := tok.Metadata["expires_in"] + if expiresIn != float64(86400) { + t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) + } +} + +func TestConfig3LO_ExchangeJSONResponseExpiry(t *testing.T) { + seconds := int32(day.Seconds()) + for _, c := range []struct { + name string + expires string + want bool + nullExpires bool + }{ + {"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false}, + {"null", fmt.Sprintf(`"expires_in": null`), true, true}, + {"wrong_type", `"expires_in": false`, false, false}, + {"wrong_type2", `"expires_in": {}`, false, false}, + {"wrong_value", `"expires_in": "zzz"`, false, false}, + } { + t.Run(c.name, func(t *testing.T) { + testConfig3LO_ExchangeJSONResponseExpiry(t, c.expires, c.want, c.nullExpires) + }) + } +} + +func testConfig3LO_ExchangeJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) + })) + defer ts.Close() + conf := newConf(ts.URL) + t1 := time.Now().Add(day) + tok, _, err := conf.exchange(context.Background(), "exchange-code") + t2 := t1.Add(day) + + if got := (err == nil); got != want { + if want { + t.Errorf("unexpected error: got %v", err) + } else { + t.Errorf("unexpected success") + } + } + if !want { + return + } + if !tok.IsValid() { + t.Fatalf("Token invalid. Got: %#v", tok) + } + expiry := tok.Expiry + + if nullExpires && expiry.IsZero() { + return + } + if expiry.Before(t1) || expiry.After(t2) { + t.Errorf("Unexpected value for Expiry: %v (should be between %v and %v)", expiry, t1, t2) + } +} + +func TestConfig3LO_ExchangeBadResponse(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf := newConf(ts.URL) + _, _, err := conf.exchange(context.Background(), "code") + if err == nil { + t.Error("expected error from missing access_token") + } +} + +func TestConfig3LO_ExchangeBadResponseType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) + })) + defer ts.Close() + conf := newConf(ts.URL) + _, _, err := conf.exchange(context.Background(), "exchange-code") + if err == nil { + t.Error("expected error from non-string access_token") + } +} + +func TestConfig3LO_RefreshTokenReplacement(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`)) + return + })) + defer ts.Close() + conf := newConf(ts.URL) + const oldRefreshToken = "OLD_REFRESH_TOKEN" + tp := conf.TokenProvider(oldRefreshToken) + _, err := tp.Token(context.Background()) + if err != nil { + t.Errorf("got err = %v; want none", err) + return + } + innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO) + if want := "NEW_REFRESH_TOKEN"; innerTP.refreshToken != want { + t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, want) + } +} + +func TestConfig3LO_RefreshTokenPreservation(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer"}`)) + return + })) + defer ts.Close() + conf := newConf(ts.URL) + const oldRefreshToken = "OLD_REFRESH_TOKEN" + tp := conf.TokenProvider(oldRefreshToken) + _, err := tp.Token(context.Background()) + if err != nil { + t.Fatalf("got err = %v; want none", err) + } + innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO) + if innerTP.refreshToken != oldRefreshToken { + t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, oldRefreshToken) + } +} + +func TestConfig3LO_AuthHandlerExchangeSuccess(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" { + return "testCode", "testState", nil + } + return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if r.Form.Get("code") == "testCode" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "pubsub", + "token_type": "bearer", + "expires_in": 3600 + }`)) + } + })) + defer ts.Close() + + conf := &Config3LO{ + ClientID: "testClientID", + Scopes: []string{"pubsub"}, + AuthURL: "testAuthCodeURL", + TokenURL: ts.URL, + AuthStyle: AuthStyleInHeader, + } + + tok, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + State: "testState", + AuthorizationHandler: authhandler, + }).Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if !tok.IsValid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.Type, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + if got := tok.Expiry.IsZero(); got { + t.Errorf("token expiry is zero = %v, want false", got) + } + scope := tok.Metadata["scope"].(string) + if got, want := scope, "pubsub"; got != want { + t.Errorf("scope = %q; want %q", got, want) + } +} + +func TestConfig3LO_AuthHandlerExchangeStateMismatch(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + return "testCode", "testStateMismatch", nil + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "pubsub", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + conf := &Config3LO{ + ClientID: "testClientID", + Scopes: []string{"pubsub"}, + AuthURL: "testAuthCodeURL", + TokenURL: ts.URL, + } + + _, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + State: "testState", + AuthorizationHandler: authhandler, + }).Token(context.Background()) + if want_err := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != want_err { + t.Errorf("err = %q; want %q", err, want_err) + } +} + +func TestConfig3LO_PKCEExchangeWithSuccess(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + if authCodeURL == "testAuthCodeURL?client_id=testClientID&code_challenge=codeChallenge&code_challenge_method=plain&response_type=code&scope=pubsub&state=testState" { + return "testCode", "testState", nil + } + return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if r.Form.Get("code") == "testCode" && r.Form.Get("code_verifier") == "codeChallenge" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "scope": "pubsub", + "token_type": "bearer", + "expires_in": 3600 + }`)) + } + })) + defer ts.Close() + + conf := &Config3LO{ + ClientID: "testClientID", + Scopes: []string{"pubsub"}, + AuthURL: "testAuthCodeURL", + TokenURL: ts.URL, + AuthStyle: AuthStyleInParams, + } + pkce := PKCEConfig{ + Challenge: "codeChallenge", + ChallengeMethod: "plain", + Verifier: "codeChallenge", + } + + tok, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + State: "testState", + AuthorizationHandler: authhandler, + PKCEConfig: &pkce, + }).Token(context.Background()) + if err != nil { + t.Fatal(err) + } + if !tok.IsValid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.Type, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + if got := tok.Expiry.IsZero(); got { + t.Errorf("token expiry is zero = %v, want false", got) + } + scope := tok.Metadata["scope"].(string) + if got, want := scope, "pubsub"; got != want { + t.Errorf("scope = %q; want %q", got, want) + } +} diff --git a/go.work b/go.work index 908db397fd39..661c9eb6d22e 100644 --- a/go.work +++ b/go.work @@ -18,6 +18,7 @@ use ( ./artifactregistry ./asset ./assuredworkloads + ./auth ./automl ./baremetalsolution ./batch diff --git a/release-please-config-individual.json b/release-please-config-individual.json index b8766e0f2ea7..9b1266f6dd29 100644 --- a/release-please-config-individual.json +++ b/release-please-config-individual.json @@ -5,6 +5,9 @@ "separate-pull-requests": true, "tag-separator": "/", "packages": { + "auth": { + "component": "auth" + }, "bigquery": { "component": "bigquery" }, From 7923b13c432377d1188993a3e0c14d2030ed4c81 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Tue, 22 Aug 2023 11:12:11 -0500 Subject: [PATCH 2/5] fix vets --- auth/auth.go | 25 ++++++++++++++++--------- auth/internal/internal.go | 1 + auth/internal/jwt/jwt.go | 9 +++++++-- auth/threelegged.go | 20 +++++++++++--------- auth/threelegged_test.go | 22 +++++++++++----------- go.work.sum | 19 +------------------ 6 files changed, 47 insertions(+), 49 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index ff0047eb7a56..1fec2d4cb464 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -117,7 +117,8 @@ func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { return ctpo.ExpireEarly } -// May need to also pass a token for the user-auth flow? +// NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned +// by the underlying provider. func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider { if ctp, ok := tp.(*cachedTokenProvider); ok { return ctp @@ -152,6 +153,8 @@ func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) { return t, nil } +// Error is a error associated with retrieving a [Token]. It can hold useful +// additional details for debugging. type Error struct { // Response is the HTTP response associated with error. The body will always // be already closed and consumed. @@ -183,6 +186,8 @@ func (r *Error) Error() string { return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", r.Response.StatusCode, r.Body) } +// Temporary returns true if the error is considered temporary and may be able +// to be retried. func (e *Error) Temporary() bool { if e.Response == nil { return false @@ -195,18 +200,18 @@ func (e *Error) Unwrap() error { return e.Err } -// AuthStyle describes how the token endpoint wants receive the ClientID and +// Style describes how the token endpoint wants receive the ClientID and // ClientSecret. -type AuthStyle int +type Style int const ( - // AuthStyleUnknown means the value has not been initiated. Sending this in + // StyleUnknown means the value has not been initiated. Sending this in // a request will cause the token exchange to fail. - AuthStyleUnknown AuthStyle = 0 - // AuthStyleInParams sends client info in the body of a POST request. - AuthStyleInParams AuthStyle = 1 - // AuthStyleInHeader sends client info using Basic Authorization header. - AuthStyleInHeader AuthStyle = 2 + StyleUnknown Style = 0 + // StyleInParams sends client info in the body of a POST request. + StyleInParams Style = 1 + // StyleInHeader sends client info using Basic Authorization header. + StyleInHeader Style = 2 ) // ConfigJWT2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow. @@ -250,6 +255,8 @@ func (c *ConfigJWT2LO) client() *http.Client { return internal.CloneDefaultClient() } +// TokenProvider returns a [TokenProvider] based on the provided fields set on +// [ConfigJWT2LO]. func (c *ConfigJWT2LO) TokenProvider() TokenProvider { return tokenProvider2LO{c: c, Client: c.client()} } diff --git a/auth/internal/internal.go b/auth/internal/internal.go index c0dfbb80813f..b1e61f33ab29 100644 --- a/auth/internal/internal.go +++ b/auth/internal/internal.go @@ -27,6 +27,7 @@ import ( ) const ( + // TokenTypeBearer is the auth header prefix for bearer tokens. TokenTypeBearer = "Bearer" quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT" diff --git a/auth/internal/jwt/jwt.go b/auth/internal/jwt/jwt.go index a07db35df924..521ecdfe33f1 100644 --- a/auth/internal/jwt/jwt.go +++ b/auth/internal/jwt/jwt.go @@ -29,11 +29,15 @@ import ( ) const ( + // HeaderAlgRSA256 is the RS256 [Header.Algorithm]. HeaderAlgRSA256 = "RS256" - HeaderAlgES256 = "ES256" - HeaderType = "JWT" + // HeaderAlgES256 is the ES256 [Header.Algorithm]. + HeaderAlgES256 = "ES256" + // HeaderType is the standard [Header.Type]. + HeaderType = "JWT" ) +// Header represents a JWT header. type Header struct { Algorithm string `json:"alg"` Type string `json:"typ"` @@ -48,6 +52,7 @@ func (h *Header) encode() (string, error) { return base64.RawURLEncoding.EncodeToString(b), nil } +// Claims represents the claims set of a JWT. type Claims struct { // Iss is the issuer JWT claim. Iss string `json:"iss"` diff --git a/auth/threelegged.go b/auth/threelegged.go index f29776d3a6e8..c152e1705266 100644 --- a/auth/threelegged.go +++ b/auth/threelegged.go @@ -57,7 +57,7 @@ type Config3LO struct { // Optional. Client *http.Client // AuthStyle is used to describe how to client info in the token request. - AuthStyle AuthStyle + AuthStyle Style // EarlyTokenExpiry is the time before the token expires that it should be // refreshed. If not set the default value is 10 seconds. Optional. EarlyTokenExpiry time.Duration @@ -65,7 +65,7 @@ type Config3LO struct { pkceConf *PKCEConfig } -// PKCEParams holds parameters to support PKCE. +// PKCEConfig holds parameters to support PKCE. type PKCEConfig struct { // Challenge is the un-padded, base64-url-encoded string of the encrypted code verifier. Challenge string // The un-padded, base64-url-encoded string of the encrypted code verifier. @@ -135,7 +135,7 @@ func (c *Config3LO) authCodeURL(state string, values url.Values) string { return buf.String() } -// TokenProvider returns a TokenProvider based the the 3-legged OAuth2 +// TokenProvider returns a TokenProvider based on the 3-legged OAuth2 // configuration. The TokenProvider is caches and auto-refreshes tokens by // default. func (c *Config3LO) TokenProvider(refreshToken string) TokenProvider { @@ -144,9 +144,9 @@ func (c *Config3LO) TokenProvider(refreshToken string) TokenProvider { }) } -// AuthHandlerOptions provides a set of options to specify for doing a +// AuthenticationHandlerOptions provides a set of options to specify for doing a // 3-legged OAuth2 flow with a custom [AuthorizationHandler]. -type AuthHandlerOptions struct { +type AuthenticationHandlerOptions struct { // AuthorizationHandler specifies the handler used to for the authorization // part of the flow. AuthorizationHandler AuthorizationHandler @@ -157,7 +157,9 @@ type AuthHandlerOptions struct { PKCEConfig *PKCEConfig } -func (c *Config3LO) TokenProviderWithAuthHandler(opts AuthHandlerOptions) TokenProvider { +// TokenProviderWithAuthHandler returns a [TokenProvider] based on the 3-legged +// OAuth2 configuration and authentication handler options. +func (c *Config3LO) TokenProviderWithAuthHandler(opts AuthenticationHandlerOptions) TokenProvider { c.pkceConf = opts.PKCEConfig return NewCachedTokenProvider(&tokenProviderWithHandler{c: c, handler: opts.AuthorizationHandler, state: opts.State}, &CachedTokenProviderOptions{ ExpireEarly: c.EarlyTokenExpiry, @@ -236,10 +238,10 @@ func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) { // fetchToken returns a Token, refresh token, and/or an error. func fetchToken(ctx context.Context, c *Config3LO, v url.Values) (*Token, string, error) { var refreshToken string - if c.AuthStyle == AuthStyleUnknown { + if c.AuthStyle == StyleUnknown { return nil, refreshToken, fmt.Errorf("auth: missing required field AuthStyle") } - if c.AuthStyle == AuthStyleInParams { + if c.AuthStyle == StyleInParams { if c.ClientID != "" { v.Set("client_id", c.ClientID) } @@ -252,7 +254,7 @@ func fetchToken(ctx context.Context, c *Config3LO, v url.Values) (*Token, string return nil, refreshToken, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - if c.AuthStyle == AuthStyleInHeader { + if c.AuthStyle == StyleInHeader { req.SetBasicAuth(url.QueryEscape(c.ClientID), url.QueryEscape(c.ClientSecret)) } diff --git a/auth/threelegged_test.go b/auth/threelegged_test.go index 018d9d3f09eb..24ddcbd9cf3a 100644 --- a/auth/threelegged_test.go +++ b/auth/threelegged_test.go @@ -43,7 +43,7 @@ func newConf(url string) *Config3LO { Scopes: []string{"scope1", "scope2"}, AuthURL: url + "/auth", TokenURL: url + "/token", - AuthStyle: AuthStyleInHeader, + AuthStyle: StyleInHeader, } } @@ -214,18 +214,18 @@ func TestConfig3LO_ExchangeJSONResponseExpiry(t *testing.T) { nullExpires bool }{ {"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false}, - {"null", fmt.Sprintf(`"expires_in": null`), true, true}, + {"null", `"expires_in": null`, true, true}, {"wrong_type", `"expires_in": false`, false, false}, {"wrong_type2", `"expires_in": {}`, false, false}, {"wrong_value", `"expires_in": "zzz"`, false, false}, } { t.Run(c.name, func(t *testing.T) { - testConfig3LO_ExchangeJSONResponseExpiry(t, c.expires, c.want, c.nullExpires) + testConfig3LOExchangeJSONResponseExpiry(t, c.expires, c.want, c.nullExpires) }) } } -func testConfig3LO_ExchangeJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) { +func testConfig3LOExchangeJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) @@ -353,10 +353,10 @@ func TestConfig3LO_AuthHandlerExchangeSuccess(t *testing.T) { Scopes: []string{"pubsub"}, AuthURL: "testAuthCodeURL", TokenURL: ts.URL, - AuthStyle: AuthStyleInHeader, + AuthStyle: StyleInHeader, } - tok, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + tok, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ State: "testState", AuthorizationHandler: authhandler, }).Token(context.Background()) @@ -404,12 +404,12 @@ func TestConfig3LO_AuthHandlerExchangeStateMismatch(t *testing.T) { TokenURL: ts.URL, } - _, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + _, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ State: "testState", AuthorizationHandler: authhandler, }).Token(context.Background()) - if want_err := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != want_err { - t.Errorf("err = %q; want %q", err, want_err) + if wantErr := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != wantErr { + t.Errorf("err = %q; want %q", err, wantErr) } } @@ -440,7 +440,7 @@ func TestConfig3LO_PKCEExchangeWithSuccess(t *testing.T) { Scopes: []string{"pubsub"}, AuthURL: "testAuthCodeURL", TokenURL: ts.URL, - AuthStyle: AuthStyleInParams, + AuthStyle: StyleInParams, } pkce := PKCEConfig{ Challenge: "codeChallenge", @@ -448,7 +448,7 @@ func TestConfig3LO_PKCEExchangeWithSuccess(t *testing.T) { Verifier: "codeChallenge", } - tok, err := conf.TokenProviderWithAuthHandler(AuthHandlerOptions{ + tok, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ State: "testState", AuthorizationHandler: authhandler, PKCEConfig: &pkce, diff --git a/go.work.sum b/go.work.sum index 8d1e2b9d1181..d72e236a3d40 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,5 +1,4 @@ cloud.google.com/go/gaming v1.9.0 h1:7vEhFnZmd931Mo7sZ6pJy7uQPDxF7m7v8xtBheG08tc= -cloud.google.com/go/gaming v1.10.1/go.mod h1:XQQvtfP8Rb9Rxnxm5wFVpAp9zCQkJi2bLIb7iHGwB3s= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= @@ -8,33 +7,17 @@ github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObk github.com/elazarl/goproxy v0.0.0-20221015165544-a0805db90819/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20230305113008-0c11038e723f/go.mod h1:8LHG1a3SRW71ettAD/jW13h8c6AqjVSeL11RAdgaqpo= -github.com/google/go-pkcs11 v0.2.0/go.mod h1:6eQoGcuNJpa7jnd5pMGdkSaQpNDYvPlXWMcjXXThLlY= github.com/google/s2a-go v0.1.3/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A= -github.com/googleapis/enterprise-certificate-proxy v0.2.4/go.mod h1:AwSRAtLfXpU5Nm3pW+v7rGDHp09LsPtGY9MduiEsR9k= github.com/googleapis/gax-go/v2 v2.9.1/go.mod h1:4FG3gMrVZlyMp5itSYKMU9z/lBE7+SbnUOvzH2HqbEY= github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/mmcloughlin/avo v0.5.0/go.mod h1:ChHFdoV7ql95Wi7vuq2YT1bwCJqiWdZrQ1im3VujLYM= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= -golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= -golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= google.golang.org/api v0.123.0/go.mod h1:gcitW0lvnyWjSp9nKxAbdHKIZ6vF4aajGueeslZOyms= -google.golang.org/api v0.128.0/go.mod h1:Y611qgqaE92On/7g65MQgxYul3c0rEB894kniWLY750= google.golang.org/genproto v0.0.0-20230526161137-0005af68ea54/go.mod h1:zqTuNwFlFRsw5zIts5VnzLQxSRqh+CGOTVMlYbY0Eyk= -google.golang.org/genproto/googleapis/api v0.0.0-20230629202037-9506855d4529/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230629202037-9506855d4529/go.mod h1:ylj+BE99M198VPbBh6A8d9n3w8fChvyLK3wwBOjXBFA= -google.golang.org/genproto/googleapis/bytestream v0.0.0-20230711160842-782d3b101e98/go.mod h1:3QoBVwTHkXbY1oRGzlhwhOykfcATQN43LJ6iT8Wy8kE= google.golang.org/genproto/googleapis/bytestream v0.0.0-20230720185612-659f7aaaa771/go.mod h1:3QoBVwTHkXbY1oRGzlhwhOykfcATQN43LJ6iT8Wy8kE= -google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5vorUY= -google.golang.org/grpc v1.56.1/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= From afff4a0a2bd2dd1a747cc870eb9d0aed70a9fec8 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Tue, 22 Aug 2023 16:18:04 -0500 Subject: [PATCH 3/5] - add constructors - switch to use * types - use the Option name consistently --- auth/auth.go | 38 +++++++------- auth/auth_test.go | 66 ++++++++++++++++-------- auth/example_test.go | 9 ++-- auth/threelegged.go | 84 +++++++++++++++++-------------- auth/threelegged_test.go | 106 +++++++++++++++++++++------------------ 5 files changed, 173 insertions(+), 130 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 1fec2d4cb464..875aa23a80df 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -214,8 +214,8 @@ const ( StyleInHeader Style = 2 ) -// ConfigJWT2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow. -type ConfigJWT2LO struct { +// Options2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow. +type Options2LO struct { // Email is the OAuth2 client ID. This value is set as the "iss" in the // JWT. Email string @@ -248,46 +248,46 @@ type ConfigJWT2LO struct { UseIDToken bool } -func (c *ConfigJWT2LO) client() *http.Client { +func (c *Options2LO) client() *http.Client { if c.Client != nil { return c.Client } return internal.CloneDefaultClient() } -// TokenProvider returns a [TokenProvider] based on the provided fields set on -// [ConfigJWT2LO]. -func (c *ConfigJWT2LO) TokenProvider() TokenProvider { - return tokenProvider2LO{c: c, Client: c.client()} +// New2LOTokenProvider returns a [TokenProvider] from the provided options. +func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) { + // TODO(codyoss): add validation + return tokenProvider2LO{opts: opts, Client: opts.client()}, nil } type tokenProvider2LO struct { - c *ConfigJWT2LO + opts *Options2LO Client *http.Client } func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { - pk, err := internal.ParseKey(tp.c.PrivateKey) + pk, err := internal.ParseKey(tp.opts.PrivateKey) if err != nil { return nil, err } claimSet := &jwt.Claims{ - Iss: tp.c.Email, - Scope: strings.Join(tp.c.Scopes, " "), - Aud: tp.c.TokenURL, - AdditionalClaims: tp.c.PrivateClaims, + Iss: tp.opts.Email, + Scope: strings.Join(tp.opts.Scopes, " "), + Aud: tp.opts.TokenURL, + AdditionalClaims: tp.opts.PrivateClaims, } - if subject := tp.c.Subject; subject != "" { + if subject := tp.opts.Subject; subject != "" { claimSet.Sub = subject } - if t := tp.c.Expires; t > 0 { + if t := tp.opts.Expires; t > 0 { claimSet.Exp = time.Now().Add(t).Unix() } - if aud := tp.c.Audience; aud != "" { + if aud := tp.opts.Audience; aud != "" { claimSet.Aud = aud } h := *defaultHeader - h.KeyID = tp.c.PrivateKeyID + h.KeyID = tp.opts.PrivateKeyID payload, err := jwt.EncodeJWS(&h, claimSet, pk) if err != nil { return nil, err @@ -295,7 +295,7 @@ func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { v := url.Values{} v.Set("grant_type", defaultGrantType) v.Set("assertion", payload) - resp, err := tp.Client.PostForm(tp.c.TokenURL, v) + resp, err := tp.Client.PostForm(tp.opts.TokenURL, v) if err != nil { return nil, fmt.Errorf("auth: cannot fetch token: %w", err) } @@ -338,7 +338,7 @@ func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { } token.Expiry = time.Unix(claimSet.Exp, 0) } - if tp.c.UseIDToken { + if tp.opts.UseIDToken { if tokenRes.IDToken == "" { return nil, fmt.Errorf("auth: response doesn't have JWT token") } diff --git a/auth/auth_test.go b/auth/auth_test.go index 4be4cbb45db0..cdb3819b2d4d 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -190,12 +190,16 @@ func TestConfigJWT2LO_JSONResponse(t *testing.T) { })) defer ts.Close() - conf := &ConfigJWT2LO{ + opts := &Options2LO{ Email: "aaa@xxx.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenProvider().Token(context.Background()) + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) } @@ -224,12 +228,16 @@ func TestConfigJWT2LO_BadResponse(t *testing.T) { })) defer ts.Close() - conf := &ConfigJWT2LO{ + opts := &Options2LO{ Email: "aaa@xxx.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenProvider().Token(context.Background()) + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) } @@ -257,12 +265,16 @@ func TestConfigJWT2LO_BadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := &ConfigJWT2LO{ + opts := &Options2LO{ Email: "aaa@xxx.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenProvider().Token(context.Background()) + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) if err == nil { t.Error("got a token; expected error") if got, want := tok.Value, ""; got != want { @@ -287,14 +299,18 @@ func TestConfigJWT2LO_Assertion(t *testing.T) { })) defer ts.Close() - conf := &ConfigJWT2LO{ + opts := &Options2LO{ Email: "aaa@xxx.com", PrivateKey: fakePrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, } - _, err := conf.TokenProvider().Token(context.Background()) + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + _, err = tp.Token(context.Background()) if err != nil { t.Fatalf("Failed to fetch token: %v", err) } @@ -339,7 +355,7 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) { })) defer ts.Close() - for _, conf := range []*ConfigJWT2LO{ + for _, opts := range []*Options2LO{ { Email: "aaa1@xxx.com", PrivateKey: fakePrivateKey, @@ -364,8 +380,12 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) { }, }, } { - t.Run(conf.Email, func(t *testing.T) { - _, err := conf.TokenProvider().Token(context.Background()) + t.Run(opts.Email, func(t *testing.T) { + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + _, err = tp.Token(context.Background()) if err != nil { t.Fatalf("Failed to fetch token: %v", err) } @@ -384,29 +404,29 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) { t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err) } - if got, want := claimSet.Iss, conf.Email; got != want { + if got, want := claimSet.Iss, opts.Email; got != want { t.Errorf("payload email = %q; want %q", got, want) } - if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want { + if got, want := claimSet.Scope, strings.Join(opts.Scopes, " "); got != want { t.Errorf("payload scope = %q; want %q", got, want) } - aud := conf.TokenURL - if conf.Audience != "" { - aud = conf.Audience + aud := opts.TokenURL + if opts.Audience != "" { + aud = opts.Audience } if got, want := claimSet.Aud, aud; got != want { t.Errorf("payload audience = %q; want %q", got, want) } - if got, want := claimSet.Sub, conf.Subject; got != want { + if got, want := claimSet.Sub, opts.Subject; got != want { t.Errorf("payload subject = %q; want %q", got, want) } - if len(conf.PrivateClaims) > 0 { + if len(opts.PrivateClaims) > 0 { var got interface{} if err := json.Unmarshal(gotjson, &got); err != nil { t.Errorf("failed to parse payload; err = %q", err) } m := got.(map[string]interface{}) - for v, k := range conf.PrivateClaims { + for v, k := range opts.PrivateClaims { if !reflect.DeepEqual(m[v], k) { t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k) } @@ -424,13 +444,17 @@ func TestConfigJWT2LO_TokenError(t *testing.T) { })) defer ts.Close() - conf := &ConfigJWT2LO{ + opts := &Options2LO{ Email: "aaa@xxx.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } - _, err := conf.TokenProvider().Token(context.Background()) + tp, err := New2LOTokenProvider(opts) + if err != nil { + t.Fatal(err) + } + _, err = tp.Token(context.Background()) if err == nil { t.Fatalf("got no error, expected one") } diff --git a/auth/example_test.go b/auth/example_test.go index bf6d733ae3c2..d4beed7bd5f9 100644 --- a/auth/example_test.go +++ b/auth/example_test.go @@ -18,10 +18,10 @@ import ( "cloud.google.com/go/auth" ) -func ExampleConfigJWT2LO() { +func ExampleNew2LOTokenProvider() { // Your credentials should be obtained from the Google // Developer Console (https://console.developers.google.com). - conf := &auth.ConfigJWT2LO{ + opts := &auth.Options2LO{ Email: "xxx@developer.gserviceaccount.com", // The contents of your RSA private key or your PEM file // that contains a private key. @@ -45,7 +45,10 @@ func ExampleConfigJWT2LO() { Subject: "user@example.com", } - tp := conf.TokenProvider() + tp, err := auth.New2LOTokenProvider(opts) + if err != nil { + // handler error + } // TODO(codyoss): Fixup once more code is merged // client, err := httptransport.NewClient(&httptransport.Options{ // TokenProvider: tp, diff --git a/auth/threelegged.go b/auth/threelegged.go index c152e1705266..f14b1eb0c452 100644 --- a/auth/threelegged.go +++ b/auth/threelegged.go @@ -36,8 +36,8 @@ import ( // state upon approval. type AuthorizationHandler func(authCodeURL string) (code string, state string, err error) -// Config3LO is the configuration settings for doing a 3-legged OAuth2 flow. -type Config3LO struct { +// Options3LO are the options for doing a 3-legged OAuth2 flow. +type Options3LO struct { // ClientID is the application's ID. ClientID string // ClientSecret is the application's secret. @@ -62,7 +62,9 @@ type Config3LO struct { // refreshed. If not set the default value is 10 seconds. Optional. EarlyTokenExpiry time.Duration - pkceConf *PKCEConfig + // AuthHandlerOpts provides a set of options for doing a + // 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional. + AuthHandlerOpts *AuthorizationHandlerOptions } // PKCEConfig holds parameters to support PKCE. @@ -93,7 +95,7 @@ func (e *tokenJSON) expiry() (t time.Time) { return } -func (c *Config3LO) client() *http.Client { +func (c *Options3LO) client() *http.Client { if c.Client != nil { return c.Client } @@ -101,7 +103,7 @@ func (c *Config3LO) client() *http.Client { } // authCodeURL returns a URL that points to a OAuth2 consent page. -func (c *Config3LO) authCodeURL(state string, values url.Values) string { +func (c *Options3LO) authCodeURL(state string, values url.Values) string { var buf bytes.Buffer buf.WriteString(c.AuthURL) v := url.Values{ @@ -117,11 +119,15 @@ func (c *Config3LO) authCodeURL(state string, values url.Values) string { if state != "" { v.Set("state", state) } - if c.pkceConf != nil && c.pkceConf.Challenge != "" && c.pkceConf.ChallengeMethod != "" { - v.Set(codeChallengeKey, c.pkceConf.Challenge) - } - if c.pkceConf != nil && c.pkceConf.ChallengeMethod != "" { - v.Set(codeChallengeMethodKey, c.pkceConf.ChallengeMethod) + if c.AuthHandlerOpts != nil { + if c.AuthHandlerOpts.PKCEConfig != nil && + c.AuthHandlerOpts.PKCEConfig.Challenge != "" { + v.Set(codeChallengeKey, c.AuthHandlerOpts.PKCEConfig.Challenge) + } + if c.AuthHandlerOpts.PKCEConfig != nil && + c.AuthHandlerOpts.PKCEConfig.ChallengeMethod != "" { + v.Set(codeChallengeMethodKey, c.AuthHandlerOpts.PKCEConfig.ChallengeMethod) + } } for k := range values { v.Set(k, v.Get(k)) @@ -135,21 +141,25 @@ func (c *Config3LO) authCodeURL(state string, values url.Values) string { return buf.String() } -// TokenProvider returns a TokenProvider based on the 3-legged OAuth2 +// New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2 // configuration. The TokenProvider is caches and auto-refreshes tokens by // default. -func (c *Config3LO) TokenProvider(refreshToken string) TokenProvider { - return NewCachedTokenProvider(&tokenProvider3LO{config: c, refreshToken: refreshToken, client: c.client()}, &CachedTokenProviderOptions{ - ExpireEarly: c.EarlyTokenExpiry, - }) +func New3LOTokenProvider(refreshToken string, opts *Options3LO) (TokenProvider, error) { + if opts.AuthHandlerOpts != nil { + return new3LOTokenProviderWithAuthHandler(opts), nil + } + // TODO(codyoss): validate the things + return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: refreshToken, client: opts.client()}, &CachedTokenProviderOptions{ + ExpireEarly: opts.EarlyTokenExpiry, + }), nil } -// AuthenticationHandlerOptions provides a set of options to specify for doing a +// AuthorizationHandlerOptions provides a set of options to specify for doing a // 3-legged OAuth2 flow with a custom [AuthorizationHandler]. -type AuthenticationHandlerOptions struct { +type AuthorizationHandlerOptions struct { // AuthorizationHandler specifies the handler used to for the authorization // part of the flow. - AuthorizationHandler AuthorizationHandler + Handler AuthorizationHandler // State is used verify that the "state" is identical in the request and // response before exchanging the auth code for OAuth2 token. State string @@ -157,18 +167,15 @@ type AuthenticationHandlerOptions struct { PKCEConfig *PKCEConfig } -// TokenProviderWithAuthHandler returns a [TokenProvider] based on the 3-legged -// OAuth2 configuration and authentication handler options. -func (c *Config3LO) TokenProviderWithAuthHandler(opts AuthenticationHandlerOptions) TokenProvider { - c.pkceConf = opts.PKCEConfig - return NewCachedTokenProvider(&tokenProviderWithHandler{c: c, handler: opts.AuthorizationHandler, state: opts.State}, &CachedTokenProviderOptions{ - ExpireEarly: c.EarlyTokenExpiry, +func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider { + return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{ + ExpireEarly: opts.EarlyTokenExpiry, }) } // exchange handles the final exchange portion of the 3lo flow. Returns a Token, // refreshToken, and error. -func (c *Config3LO) exchange(ctx context.Context, code string) (*Token, string, error) { +func (c *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) { // Build request v := url.Values{ "grant_type": {"authorization_code"}, @@ -177,8 +184,10 @@ func (c *Config3LO) exchange(ctx context.Context, code string) (*Token, string, if c.RedirectURL != "" { v.Set("redirect_uri", c.RedirectURL) } - if c.pkceConf != nil && c.pkceConf.Verifier != "" { - v.Set(codeVerifierKey, c.pkceConf.Verifier) + if c.AuthHandlerOpts != nil && + c.AuthHandlerOpts.PKCEConfig != nil && + c.AuthHandlerOpts.PKCEConfig.Verifier != "" { + v.Set(codeVerifierKey, c.AuthHandlerOpts.PKCEConfig.Verifier) } for k := range c.URLParams { v.Set(k, c.URLParams.Get(k)) @@ -189,7 +198,7 @@ func (c *Config3LO) exchange(ctx context.Context, code string) (*Token, string, // This struct is not safe for concurrent access alone, but the way it is used // in this package by wrapping it with a cachedTokenProvider makes it so. type tokenProvider3LO struct { - config *Config3LO + opts *Options3LO client *http.Client refreshToken string } @@ -202,11 +211,11 @@ func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) { "grant_type": {"refresh_token"}, "refresh_token": {tp.refreshToken}, } - for k := range tp.config.URLParams { - v.Set(k, tp.config.URLParams.Get(k)) + for k := range tp.opts.URLParams { + v.Set(k, tp.opts.URLParams.Get(k)) } - tk, rt, err := fetchToken(ctx, tp.config, v) + tk, rt, err := fetchToken(ctx, tp.opts, v) if err != nil { return nil, err } @@ -217,26 +226,25 @@ func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) { } type tokenProviderWithHandler struct { - c *Config3LO - handler AuthorizationHandler - state string + opts *Options3LO + state string } func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) { - url := tp.c.authCodeURL(tp.state, nil) - code, state, err := tp.handler(url) + url := tp.opts.authCodeURL(tp.state, nil) + code, state, err := tp.opts.AuthHandlerOpts.Handler(url) if err != nil { return nil, err } if state != tp.state { return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow") } - tok, _, err := tp.c.exchange(ctx, code) + tok, _, err := tp.opts.exchange(ctx, code) return tok, err } // fetchToken returns a Token, refresh token, and/or an error. -func fetchToken(ctx context.Context, c *Config3LO, v url.Values) (*Token, string, error) { +func fetchToken(ctx context.Context, c *Options3LO, v url.Values) (*Token, string, error) { var refreshToken string if c.AuthStyle == StyleUnknown { return nil, refreshToken, fmt.Errorf("auth: missing required field AuthStyle") diff --git a/auth/threelegged_test.go b/auth/threelegged_test.go index 24ddcbd9cf3a..2c2c4d2b5384 100644 --- a/auth/threelegged_test.go +++ b/auth/threelegged_test.go @@ -27,16 +27,8 @@ import ( const day = 24 * time.Hour -type mockTransport struct { - rt func(req *http.Request) (resp *http.Response, err error) -} - -func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - return t.rt(req) -} - -func newConf(url string) *Config3LO { - return &Config3LO{ +func newOpts(url string) *Options3LO { + return &Options3LO{ ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", RedirectURL: "REDIRECT_URL", @@ -57,7 +49,7 @@ func TestConfig3LO_URLUnsafe(t *testing.T) { w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) conf.ClientID = "CLIENT_ID??" conf.ClientSecret = "CLIENT_SECRET??" _, _, err := conf.exchange(context.Background(), "exchange-code") @@ -90,7 +82,7 @@ func TestConfig3LO_StandardExchange(t *testing.T) { w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) tok, _, err := conf.exchange(context.Background(), "exchange-code") if err != nil { t.Error(err) @@ -134,7 +126,7 @@ func TestConfig3LO_ExchangeCustomParams(t *testing.T) { w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) conf.URLParams = url.Values{} conf.URLParams.Set("foo", "bar") @@ -181,7 +173,7 @@ func TestConfig3LO_ExchangeJSONResponse(t *testing.T) { w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) tok, _, err := conf.exchange(context.Background(), "exchange-code") if err != nil { t.Error(err) @@ -231,7 +223,7 @@ func testConfig3LOExchangeJSONResponseExpiry(t *testing.T, exp string, want, nul w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) t1 := time.Now().Add(day) tok, _, err := conf.exchange(context.Background(), "exchange-code") t2 := t1.Add(day) @@ -265,7 +257,7 @@ func TestConfig3LO_ExchangeBadResponse(t *testing.T) { w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) _, _, err := conf.exchange(context.Background(), "code") if err == nil { t.Error("expected error from missing access_token") @@ -278,7 +270,7 @@ func TestConfig3LO_ExchangeBadResponseType(t *testing.T) { w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) })) defer ts.Close() - conf := newConf(ts.URL) + conf := newOpts(ts.URL) _, _, err := conf.exchange(context.Background(), "exchange-code") if err == nil { t.Error("expected error from non-string access_token") @@ -289,14 +281,15 @@ func TestConfig3LO_RefreshTokenReplacement(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`)) - return })) defer ts.Close() - conf := newConf(ts.URL) + opts := newOpts(ts.URL) const oldRefreshToken = "OLD_REFRESH_TOKEN" - tp := conf.TokenProvider(oldRefreshToken) - _, err := tp.Token(context.Background()) + tp, err := New3LOTokenProvider(oldRefreshToken, opts) if err != nil { + t.Fatal(err) + } + if _, err := tp.Token(context.Background()); err != nil { t.Errorf("got err = %v; want none", err) return } @@ -310,15 +303,17 @@ func TestConfig3LO_RefreshTokenPreservation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer"}`)) - return })) defer ts.Close() - conf := newConf(ts.URL) + opts := newOpts(ts.URL) const oldRefreshToken = "OLD_REFRESH_TOKEN" - tp := conf.TokenProvider(oldRefreshToken) - _, err := tp.Token(context.Background()) + tp, err := New3LOTokenProvider(oldRefreshToken, opts) if err != nil { - t.Fatalf("got err = %v; want none", err) + t.Fatal(err) + } + if _, err := tp.Token(context.Background()); err != nil { + t.Errorf("got err = %v; want none", err) + return } innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO) if innerTP.refreshToken != oldRefreshToken { @@ -348,18 +343,23 @@ func TestConfig3LO_AuthHandlerExchangeSuccess(t *testing.T) { })) defer ts.Close() - conf := &Config3LO{ + opts := &Options3LO{ ClientID: "testClientID", Scopes: []string{"pubsub"}, AuthURL: "testAuthCodeURL", TokenURL: ts.URL, AuthStyle: StyleInHeader, + AuthHandlerOpts: &AuthorizationHandlerOptions{ + State: "testState", + Handler: authhandler, + }, } - tok, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ - State: "testState", - AuthorizationHandler: authhandler, - }).Token(context.Background()) + tp, err := New3LOTokenProvider("", opts) + if err != nil { + t.Fatal(err) + } + tok, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) } @@ -397,17 +397,21 @@ func TestConfig3LO_AuthHandlerExchangeStateMismatch(t *testing.T) { })) defer ts.Close() - conf := &Config3LO{ + opts := &Options3LO{ ClientID: "testClientID", Scopes: []string{"pubsub"}, AuthURL: "testAuthCodeURL", TokenURL: ts.URL, + AuthHandlerOpts: &AuthorizationHandlerOptions{ + State: "testState", + Handler: authhandler, + }, } - - _, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ - State: "testState", - AuthorizationHandler: authhandler, - }).Token(context.Background()) + tp, err := New3LOTokenProvider("", opts) + if err != nil { + t.Fatal(err) + } + _, err = tp.Token(context.Background()) if wantErr := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != wantErr { t.Errorf("err = %q; want %q", err, wantErr) } @@ -435,24 +439,28 @@ func TestConfig3LO_PKCEExchangeWithSuccess(t *testing.T) { })) defer ts.Close() - conf := &Config3LO{ + opts := &Options3LO{ ClientID: "testClientID", Scopes: []string{"pubsub"}, AuthURL: "testAuthCodeURL", TokenURL: ts.URL, AuthStyle: StyleInParams, + AuthHandlerOpts: &AuthorizationHandlerOptions{ + State: "testState", + Handler: authhandler, + PKCEConfig: &PKCEConfig{ + Challenge: "codeChallenge", + ChallengeMethod: "plain", + Verifier: "codeChallenge", + }, + }, + } + + tp, err := New3LOTokenProvider("", opts) + if err != nil { + t.Fatal(err) } - pkce := PKCEConfig{ - Challenge: "codeChallenge", - ChallengeMethod: "plain", - Verifier: "codeChallenge", - } - - tok, err := conf.TokenProviderWithAuthHandler(AuthenticationHandlerOptions{ - State: "testState", - AuthorizationHandler: authhandler, - PKCEConfig: &pkce, - }).Token(context.Background()) + tok, err := tp.Token(context.Background()) if err != nil { t.Fatal(err) } From 44f1a02f42ad497c3997aa3f40886e2ef4f18508 Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Wed, 23 Aug 2023 09:44:31 -0500 Subject: [PATCH 4/5] pr feedback --- .release-please-manifest-individual.json | 2 +- auth/auth.go | 16 ++++++------- auth/auth_test.go | 30 ++++++++++++------------ auth/internal/jwt/jwt.go | 6 ++--- auth/threelegged.go | 1 + 5 files changed, 27 insertions(+), 28 deletions(-) diff --git a/.release-please-manifest-individual.json b/.release-please-manifest-individual.json index bea575724b99..d526a3e881c9 100644 --- a/.release-please-manifest-individual.json +++ b/.release-please-manifest-individual.json @@ -11,4 +11,4 @@ "pubsublite": "1.8.1", "spanner": "1.48.0", "storage": "1.32.0" -} \ No newline at end of file +} diff --git a/auth/auth.go b/auth/auth.go index 875aa23a80df..d2914d87a6d2 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -99,7 +99,7 @@ type CachedTokenProviderOptions struct { // even if it is expired. DisableAutoRefresh bool // ExpireEarly configures the amount of time before a token expires, that it - // should be refreshed. + // should be refreshed. If unset, the default value is 10 seconds. ExpireEarly time.Duration } @@ -193,7 +193,7 @@ func (e *Error) Temporary() bool { return false } sc := e.Response.StatusCode - return sc == 500 || sc == 503 || sc == 408 || sc == 429 + return sc == http.StatusInternalServerError || sc == http.StatusServiceUnavailable || sc == http.StatusRequestTimeout || sc == http.StatusTooManyRequests } func (e *Error) Unwrap() error { @@ -207,11 +207,11 @@ type Style int const ( // StyleUnknown means the value has not been initiated. Sending this in // a request will cause the token exchange to fail. - StyleUnknown Style = 0 + StyleUnknown Style = iota // StyleInParams sends client info in the body of a POST request. - StyleInParams Style = 1 + StyleInParams // StyleInHeader sends client info using Basic Authorization header. - StyleInHeader Style = 2 + StyleInHeader ) // Options2LO is the configuration settings for doing a 2-legged JWT OAuth2 flow. @@ -276,9 +276,7 @@ func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { Scope: strings.Join(tp.opts.Scopes, " "), Aud: tp.opts.TokenURL, AdditionalClaims: tp.opts.PrivateClaims, - } - if subject := tp.opts.Subject; subject != "" { - claimSet.Sub = subject + Sub: tp.opts.Subject, } if t := tp.opts.Expires; t > 0 { claimSet.Exp = time.Now().Add(t).Unix() @@ -304,7 +302,7 @@ func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) { if err != nil { return nil, fmt.Errorf("auth: cannot fetch token: %w", err) } - if c := resp.StatusCode; c < 200 || c > 299 { + if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices { return nil, &Error{ Response: resp, Body: body, diff --git a/auth/auth_test.go b/auth/auth_test.go index cdb3819b2d4d..21453442ce1e 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -65,27 +65,27 @@ func TestError_Temporary(t *testing.T) { }{ { name: "temporary with 500", - code: 500, + code: http.StatusInternalServerError, want: true, }, { name: "temporary with 503", - code: 503, + code: http.StatusServiceUnavailable, want: true, }, { name: "temporary with 408", - code: 408, + code: http.StatusRequestTimeout, want: true, }, { name: "temporary with 429", - code: 429, + code: http.StatusTooManyRequests, want: true, }, { name: "temporary with 418", - code: 418, + code: http.StatusTeapot, want: false, }, } @@ -148,14 +148,14 @@ func TestError_Error(t *testing.T) { { name: "basic", Response: &http.Response{ - StatusCode: 418, + StatusCode: http.StatusTeapot, }, Body: []byte("I'm a teapot"), want: "auth: cannot fetch token: 418\nResponse: I'm a teapot", }, { name: "from query", - code: "418", + code: fmt.Sprint(http.StatusTeapot), description: "I'm a teapot", uri: "somewhere", want: "auth: \"418\" \"I'm a teapot\" \"somewhere\"", @@ -191,7 +191,7 @@ func TestConfigJWT2LO_JSONResponse(t *testing.T) { defer ts.Close() opts := &Options2LO{ - Email: "aaa@xxx.com", + Email: "aaa@example.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } @@ -229,7 +229,7 @@ func TestConfigJWT2LO_BadResponse(t *testing.T) { defer ts.Close() opts := &Options2LO{ - Email: "aaa@xxx.com", + Email: "aaa@example.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } @@ -266,7 +266,7 @@ func TestConfigJWT2LO_BadResponseType(t *testing.T) { })) defer ts.Close() opts := &Options2LO{ - Email: "aaa@xxx.com", + Email: "aaa@example.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } @@ -300,7 +300,7 @@ func TestConfigJWT2LO_Assertion(t *testing.T) { defer ts.Close() opts := &Options2LO{ - Email: "aaa@xxx.com", + Email: "aaa@example.com", PrivateKey: fakePrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, @@ -357,20 +357,20 @@ func TestConfigJWT2LO_AssertionPayload(t *testing.T) { for _, opts := range []*Options2LO{ { - Email: "aaa1@xxx.com", + Email: "aaa1@example.com", PrivateKey: fakePrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, }, { - Email: "aaa2@xxx.com", + Email: "aaa2@example.com", PrivateKey: fakePrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, Audience: "https://example.com", }, { - Email: "aaa2@xxx.com", + Email: "aaa2@example.com", PrivateKey: fakePrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, @@ -445,7 +445,7 @@ func TestConfigJWT2LO_TokenError(t *testing.T) { defer ts.Close() opts := &Options2LO{ - Email: "aaa@xxx.com", + Email: "aaa@example.com", PrivateKey: fakePrivateKey, TokenURL: ts.URL, } diff --git a/auth/internal/jwt/jwt.go b/auth/internal/jwt/jwt.go index 521ecdfe33f1..dc28b3c3bb54 100644 --- a/auth/internal/jwt/jwt.go +++ b/auth/internal/jwt/jwt.go @@ -58,9 +58,9 @@ type Claims struct { Iss string `json:"iss"` // Scope is the scope JWT claim. Scope string `json:"scope,omitempty"` - // Exp is the expiry JWT claim. + // Exp is the expiry JWT claim. If unset, default is in one hour from now. Exp int64 `json:"exp"` - // Iat is the subject issued at claim. + // Iat is the subject issued at claim. If unset, default is now. Iat int64 `json:"iat"` // Aud is the audience JWT claim. Optional. Aud string `json:"aud"` @@ -95,7 +95,7 @@ func (c *Claims) encode() (string, error) { // Marshal private claim set and then append it to b. prv, err := json.Marshal(c.AdditionalClaims) if err != nil { - return "", fmt.Errorf("invalid map of additional claims %v", c.AdditionalClaims) + return "", fmt.Errorf("invalid map of additional claims %v: %w", c.AdditionalClaims, err) } // Concatenate public and private claim JSON objects. diff --git a/auth/threelegged.go b/auth/threelegged.go index f14b1eb0c452..94c28b4d5b2f 100644 --- a/auth/threelegged.go +++ b/auth/threelegged.go @@ -284,6 +284,7 @@ func fetchToken(ctx context.Context, c *Options3LO, v url.Values) (*Token, strin } var token *Token + // errors ignored because of default switch on content content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) switch content { case "application/x-www-form-urlencoded", "text/plain": From 1f00148cbbec977b675930d9be4ab58543d256da Mon Sep 17 00:00:00 2001 From: Cody Oss Date: Wed, 23 Aug 2023 12:39:24 -0500 Subject: [PATCH 5/5] pr feedback --- auth/README.md | 2 +- auth/auth.go | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/auth/README.md b/auth/README.md index 44735e0d0526..36de276a0743 100644 --- a/auth/README.md +++ b/auth/README.md @@ -1,4 +1,4 @@ # auth This module is currently EXPERIMENTAL and under active development. It is not -yet indented to be used. +yet intended to be used. diff --git a/auth/auth.go b/auth/auth.go index d2914d87a6d2..58acd93b12ce 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -118,7 +118,9 @@ func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { } // NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned -// by the underlying provider. +// by the underlying provider. By default it will refresh tokens ten seconds +// before they expire, but this time can be configured with the optional +// options. func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider { if ctp, ok := tp.(*cachedTokenProvider); ok { return ctp @@ -200,7 +202,7 @@ func (e *Error) Unwrap() error { return e.Err } -// Style describes how the token endpoint wants receive the ClientID and +// Style describes how the token endpoint wants to receive the ClientID and // ClientSecret. type Style int