Skip to content

Commit

Permalink
Use local variable to reduce read locks
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Koch committed Nov 26, 2021
1 parent ff15efb commit 660a635
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
38 changes: 20 additions & 18 deletions accesscontrol/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,22 @@ func NewJWKS(uri string, ttl string, transport http.RoundTripper, confContext co
}

func (j *JWKS) GetKeys(kid string) ([]JWK, error) {
var keys []JWK
var (
keys []JWK
err error
)

j.mtx.RLock()
lKeys := len(j.Keys)
allKeys := j.Keys
j.mtx.RUnlock()
if lKeys == 0 || j.hasExpired() {
if err := j.Load(); err != nil {
if len(allKeys) == 0 || j.hasExpired() {
allKeys, err = j.Load()
if err != nil {
return keys, fmt.Errorf("error loading JWKS: %v", err)
}
}

j.mtx.RLock()
ks := j.Keys
j.mtx.RUnlock()
for _, key := range ks {
for _, key := range allKeys {
if key.KeyID == kid {
keys = append(keys, key)
}
Expand All @@ -87,55 +88,56 @@ func (j *JWKS) GetKey(kid string, alg string, use string) (*JWK, error) {
return nil, nil
}

func (j *JWKS) Load() error {
func (j *JWKS) Load() ([]JWK, error) {
var rawJSON []byte

if j.file != "" {
j, err := reader.ReadFromFile("jwks_url", j.file)
if err != nil {
return err
return nil, err
}
rawJSON = j
} else if j.transport != nil {
req, err := http.NewRequest("GET", "", nil)
if err != nil {
return err
return nil, err
}
ctx := context.WithValue(j.context, request.URLAttribute, j.uri)
// TODO which roundtrip name?
ctx = context.WithValue(ctx, request.RoundTripName, "jwks")
req = req.WithContext(ctx)
response, err := j.transport.RoundTrip(req)
if err != nil {
return err
return nil, err
}
if response.StatusCode != 200 {
return fmt.Errorf("status code %d", response.StatusCode)
return nil, fmt.Errorf("status code %d", response.StatusCode)
}

defer response.Body.Close()

body, err := ioutil.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err)
return nil, fmt.Errorf("error reading JWKS response for %q: %v", j.uri, err)
}
rawJSON = body
} else {
return fmt.Errorf("jwks: missing both file and request")
return nil, fmt.Errorf("jwks: missing both file and request")
}

var jwks JWKS
err := json.Unmarshal(rawJSON, &jwks)
if err != nil {
return err
return nil, err
}

j.mtx.Lock()
defer j.mtx.Unlock()

j.Keys = jwks.Keys
j.expiry = time.Now().Unix() + int64(j.ttl.Seconds())
j.mtx.Unlock()

return nil
return j.Keys, nil
}

func (jwks *JWKS) hasExpired() bool {
Expand Down
4 changes: 2 additions & 2 deletions accesscontrol/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Test_JWKS_Load(t *testing.T) {
t.Run(tt.name, func(subT *testing.T) {
jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil)
helper.Must(err)
err = jwks.Load()
_, err = jwks.Load()
if err != nil && tt.expParsed {
subT.Error("no jwks parsed")
}
Expand Down Expand Up @@ -92,7 +92,7 @@ func Test_JWKS_GetKey(t *testing.T) {
helper := test.New(subT)
jwks, err := ac.NewJWKS("file:"+tt.file, "", nil, nil)
helper.Must(err)
err = jwks.Load()
_, err = jwks.Load()
helper.Must(err)
jwk, err := jwks.GetKey(tt.kid, tt.alg, tt.use)
if jwk == nil && tt.expFound {
Expand Down

0 comments on commit 660a635

Please sign in to comment.