diff --git a/oidc/jwks.go b/oidc/jwks.go index 3784cf29..cbe2d2d3 100644 --- a/oidc/jwks.go +++ b/oidc/jwks.go @@ -46,14 +46,14 @@ type RemoteKeySet struct { inflight *inflight // A set of cached keys. - cachedKeys []jose.JSONWebKey + cachedKeys map[string]jose.JSONWebKey } // inflight is used to wait on some in-flight request from multiple goroutines. type inflight struct { doneCh chan struct{} - keys []jose.JSONWebKey + keys map[string]jose.JSONWebKey err error } @@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} { // done can only be called by a single goroutine. It records the result of the // inflight request and signals other goroutines that the result is safe to // inspect. -func (i *inflight) done(keys []jose.JSONWebKey, err error) { +func (i *inflight) done(keys map[string]jose.JSONWebKey, err error) { i.keys = keys i.err = err close(i.doneCh) } // result cannot be called until the wait() channel has returned a value. -func (i *inflight) result() ([]jose.JSONWebKey, error) { +func (i *inflight) result() (map[string]jose.JSONWebKey, error) { return i.keys, i.err } @@ -102,35 +102,45 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) ( break } - keys := r.keysFromCache() - for _, key := range keys { - if keyID == "" || key.KeyID == keyID { - if payload, err := jws.Verify(&key); err == nil { - return payload, nil - } - } + if payload, ok := r.verifyWithKey(jws, keyID); ok { + return payload, nil } - // If the kid doesn't match, check for new keys from the remote. This is the // strategy recommended by the spec. // // https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys - keys, err := r.keysFromRemote(ctx) + _, err := r.keysFromRemote(ctx) if err != nil { return nil, fmt.Errorf("fetching keys %v", err) } - for _, key := range keys { - if keyID == "" || key.KeyID == keyID { + if payload, ok := r.verifyWithKey(jws, keyID); ok { + return payload, nil + } + + return nil, errors.New("failed to verify id token signature") +} + +// verifyWithKey attempts to verify the jws using the key with keyID from the cache +// if keyID is the empty string, it tries each key in the cache +func (r *RemoteKeySet) verifyWithKey(jws *jose.JSONWebSignature, keyID string) (payload []byte, ok bool) { + if keyID == "" { + for _, key := range r.keysFromCache() { if payload, err := jws.Verify(&key); err == nil { - return payload, nil + return payload, true + } + } + } else { + if key, ok := r.keysFromCache()[keyID]; ok { + if payload, err := jws.Verify(&key); err == nil { + return payload, true } } } - return nil, errors.New("failed to verify id token signature") + return nil, false } -func (r *RemoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { +func (r *RemoteKeySet) keysFromCache() (keys map[string]jose.JSONWebKey) { r.mu.Lock() defer r.mu.Unlock() return r.cachedKeys @@ -138,7 +148,7 @@ func (r *RemoteKeySet) keysFromCache() (keys []jose.JSONWebKey) { // keysFromRemote syncs the key set from the remote set, records the values in the // cache, and returns the key set. -func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) { +func (r *RemoteKeySet) keysFromRemote(ctx context.Context) (map[string]jose.JSONWebKey, error) { // Need to lock to inspect the inflight request field. r.mu.Lock() // If there's not a current inflight request, create one. @@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e } } -func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) { +func (r *RemoteKeySet) updateKeys() (map[string]jose.JSONWebKey, error) { req, err := http.NewRequest("GET", r.jwksURL, nil) if err != nil { return nil, fmt.Errorf("oidc: can't create request: %v", err) @@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) { if err != nil { return nil, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body) } - return keySet.Keys, nil + keys := make(map[string]jose.JSONWebKey) + for _, key := range keySet.Keys { + keys[key.KeyID] = key + } + return keys, nil }