Skip to content
This repository has been archived by the owner on Jan 30, 2025. It is now read-only.

[DOM-49091] Refresh JWKS when kid not found #1

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ type JWTAuth struct {

logger *zap.Logger
parsedSignKey interface{} // can be []byte, *rsa.PublicKey, *ecdsa.PublicKey, etc.
jwkCachedSet jwk.Set
jwkCache *jwk.Cache
}

// CaddyModule implements caddy.Module interface.
Expand Down Expand Up @@ -163,10 +163,20 @@ func (ja *JWTAuth) usingJWK() bool {
}

func (ja *JWTAuth) setupJWKLoader() {
cache := jwk.NewCache(context.Background(), jwk.WithErrSink(ja))
cache.Register(ja.JWKURL)
ja.jwkCachedSet = jwk.NewCachedSet(cache, ja.JWKURL)
ja.logger.Info("using JWKs from URL", zap.String("url", ja.JWKURL), zap.Int("loaded_keys", ja.jwkCachedSet.Len()))
ja.jwkCache = jwk.NewCache(context.Background(), jwk.WithErrSink(ja))

// TODO: jwk.WithMinRefreshInterval OR jwk.WithRefreshInterval
ja.jwkCache.Register(ja.JWKURL)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
if _, err := ja.jwkCache.Refresh(ctx, ja.JWKURL); err != nil {
// url is not a valid JWKS
panic(err)
}
// TODO: why are we using a CachedSet rather than a Cache here? With Cache we can explicitly call Refresh
// ja.jwkCachedSet = jwk.NewCachedSet(cache, ja.JWKURL)
set, _ := ja.jwkCache.Get(context.Background(), ja.JWKURL)
ja.logger.Info("using JWKs from URL", zap.String("url", ja.JWKURL), zap.Int("loaded_keys", set.Len()))
}

// Validate implements caddy.Validator interface.
Expand Down Expand Up @@ -211,12 +221,22 @@ func (ja *JWTAuth) keyProvider() jws.KeyProviderFunc {
return func(_ context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error {
if ja.usingJWK() {
kid := sig.ProtectedHeaders().KeyID()
key, found := ja.jwkCachedSet.LookupKeyID(kid)

set, _ := ja.jwkCache.Get(context.Background(), ja.JWKURL)
key, found := set.LookupKeyID(kid)
if !found {
if kid == "" {
return fmt.Errorf("missing kid in JWT header")
}
return fmt.Errorf("key specified by kid %q not found in JWKs", kid)
// TODO: refresh the cache if the kid isn't found and try again
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

_, _ = ja.jwkCache.Refresh(ctx, ja.JWKURL)
key, found = set.LookupKeyID(kid)
if !found {
return fmt.Errorf("key specified by kid %q not found in JWKs", kid)
}
}
sink.Key(ja.determineSigningAlgorithm(key.Algorithm()), key)
} else {
Expand Down
10 changes: 7 additions & 3 deletions jwt_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package caddyjwt

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
Expand Down Expand Up @@ -637,7 +638,8 @@ func TestJWK(t *testing.T) {
time.Sleep(3 * time.Second)
ja := &JWTAuth{JWKURL: TestJWKURL, logger: testLogger}
assert.Nil(t, ja.Validate())
assert.Equal(t, 1, ja.jwkCachedSet.Len())
set, _ := ja.jwkCache.Get(context.Background(), ja.JWKURL)
assert.Equal(t, 1, set.Len())

token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
rw := httptest.NewRecorder()
Expand All @@ -653,7 +655,8 @@ func TestJWKSet(t *testing.T) {
time.Sleep(3 * time.Second)
ja := &JWTAuth{JWKURL: TestJWKSetURL, logger: testLogger}
assert.Nil(t, ja.Validate())
assert.Equal(t, 2, ja.jwkCachedSet.Len())
set, _ := ja.jwkCache.Get(context.Background(), ja.JWKURL)
assert.Equal(t, 2, set.Len())

token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
rw := httptest.NewRecorder()
Expand All @@ -669,7 +672,8 @@ func TestJWKSet_KeyNotFound(t *testing.T) {
time.Sleep(3 * time.Second)
ja := &JWTAuth{JWKURL: TestJWKSetURLInapplicable, logger: testLogger}
assert.Nil(t, ja.Validate())
assert.Equal(t, 2, ja.jwkCachedSet.Len())
set, _ := ja.jwkCache.Get(context.Background(), ja.JWKURL)
assert.Equal(t, 2, set.Len())

token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
rw := httptest.NewRecorder()
Expand Down