Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jwks: refactor RemoteKeySet cache to a map #293

Open
wants to merge 1 commit into
base: v3
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions oidc/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ type RemoteKeySet struct {
inflight *inflight

// A set of cached keys.
cachedKeys []jose.JSONWebKey
cachedKeys map[string]jose.JSONWebKey
}

// inflight is used to wait on some in-flight request from multiple goroutines.
type inflight struct {
doneCh chan struct{}

keys []jose.JSONWebKey
keys map[string]jose.JSONWebKey
err error
}

Expand All @@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} {
// done can only be called by a single goroutine. It records the result of the
// inflight request and signals other goroutines that the result is safe to
// inspect.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
func (i *inflight) done(keys map[string]jose.JSONWebKey, err error) {
i.keys = keys
i.err = err
close(i.doneCh)
}

// result cannot be called until the wait() channel has returned a value.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
func (i *inflight) result() (map[string]jose.JSONWebKey, error) {
return i.keys, i.err
}

Expand All @@ -102,43 +102,53 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
break
}

keys := r.keysFromCache()
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
}
}
if payload, ok := r.verifyWithKey(jws, keyID); ok {
return payload, nil
}

// If the kid doesn't match, check for new keys from the remote. This is the
// strategy recommended by the spec.
//
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
keys, err := r.keysFromRemote(ctx)
_, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
}

for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, ok := r.verifyWithKey(jws, keyID); ok {
return payload, nil
}

return nil, errors.New("failed to verify id token signature")
}

// verifyWithKey attempts to verify the jws using the key with keyID from the cache
// if keyID is the empty string, it tries each key in the cache
func (r *RemoteKeySet) verifyWithKey(jws *jose.JSONWebSignature, keyID string) (payload []byte, ok bool) {
if keyID == "" {
for _, key := range r.keysFromCache() {
if payload, err := jws.Verify(&key); err == nil {
return payload, nil
return payload, true
}
}
} else {
if key, ok := r.keysFromCache()[keyID]; ok {
if payload, err := jws.Verify(&key); err == nil {
return payload, true
}
}
}
return nil, errors.New("failed to verify id token signature")
return nil, false
}

func (r *RemoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
func (r *RemoteKeySet) keysFromCache() (keys map[string]jose.JSONWebKey) {
r.mu.Lock()
defer r.mu.Unlock()
return r.cachedKeys
}

// keysFromRemote syncs the key set from the remote set, records the values in the
// cache, and returns the key set.
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) (map[string]jose.JSONWebKey, error) {
// Need to lock to inspect the inflight request field.
r.mu.Lock()
// If there's not a current inflight request, create one.
Expand Down Expand Up @@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
}
}

func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
func (r *RemoteKeySet) updateKeys() (map[string]jose.JSONWebKey, error) {
req, err := http.NewRequest("GET", r.jwksURL, nil)
if err != nil {
return nil, fmt.Errorf("oidc: can't create request: %v", err)
Expand All @@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
if err != nil {
return nil, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
}
return keySet.Keys, nil
keys := make(map[string]jose.JSONWebKey)
for _, key := range keySet.Keys {
keys[key.KeyID] = key
}
return keys, nil
}