Skip to content

Commit

Permalink
fix: cpu contention when reading JWKs and suppress generating duplica…
Browse files Browse the repository at this point in the history
…te JWKs (#3870)

Previously each concurrent caller would need to lock a shared mutex when reading or writing a given JWK set.
The read path now doesn't require locking a mutex at all and instead returns valid query results directly.

The write path is now protected by a concurrency control mechanism (using x/sync/singleflight) to ensure only one JWK set is generated and persisted.

Note: Duplicate JWK sets may still be improperly generated if running more than one Hydra instance in a high traffic environment.
  • Loading branch information
terev authored Nov 4, 2024
1 parent f777fd1 commit d5f65c5
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 63 deletions.
2 changes: 1 addition & 1 deletion cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con
}

// no certificates configured: self-sign a new cert
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
if err != nil {
d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair")
return nil // in case Fatal is hooked
Expand Down
13 changes: 3 additions & 10 deletions jwk/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/ory/x/httprouterx"

"github.com/gofrs/uuid"
"github.com/pkg/errors"

"github.com/ory/x/urlx"

Expand Down Expand Up @@ -101,17 +100,11 @@ func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) {
for _, set := range wellKnownKeys {
set := set
eg.Go(func() error {
k, err := h.r.KeyManager().GetKeySet(ctx, set)
if errors.Is(err, x.ErrNotFound) {
h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set)
k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig")
if err != nil {
return err
}
} else if err != nil {
keySet, err := GetOrGenerateKeySet(ctx, h.r, h.r.KeyManager(), set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
if err != nil {
return err
}
keys <- ExcludePrivateKeys(k)
keys <- ExcludePrivateKeys(keySet)
return nil
})
}
Expand Down
78 changes: 38 additions & 40 deletions jwk/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,67 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"sync"

"golang.org/x/sync/singleflight"

hydra "github.com/ory/hydra-client-go/v2"
"github.com/ory/hydra/v2/x"

"github.com/ory/x/josex"

"github.com/ory/x/errorsx"

"github.com/ory/hydra/v2/x"

jose "github.com/go-jose/go-jose/v3"
"github.com/pkg/errors"
)

var mapLock sync.RWMutex
var locks = map[string]*sync.RWMutex{}

func getLock(set string) *sync.RWMutex {
mapLock.Lock()
defer mapLock.Unlock()
if _, ok := locks[set]; !ok {
locks[set] = new(sync.RWMutex)
}
return locks[set]
}
var jwkGenFlightGroup singleflight.Group

func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error {
_, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg)
_, err := GetOrGenerateKeySetPrivateKey(ctx, r, r.KeyManager(), set, set, alg)
return err
}

func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) {
getLock(set).Lock()
defer getLock(set).Unlock()

keys, err := m.GetKeySet(ctx, set)
if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 {
r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set)
keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
if err != nil {
return nil, err
}
} else if err != nil {
func GetOrGenerateKeySetPrivateKey(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) {
keySet, err := GetOrGenerateKeySet(ctx, r, m, set, kid, alg)
if err != nil {
return nil, err
}

privKey, privKeyErr := FindPrivateKey(keys)
if privKeyErr == nil {
privKey, err := FindPrivateKey(keySet)
if err == nil {
return privKey, nil
} else {
r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set)
}

keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
if err != nil {
return nil, err
}
keySet, err = generateKeySet(ctx, r, m, set, kid, alg)
if err != nil {
return nil, err
}

privKey, err := FindPrivateKey(keys)
if err != nil {
return nil, err
}
return privKey, nil
return FindPrivateKey(keySet)
}

func GetOrGenerateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) {
keys, err := m.GetKeySet(ctx, set)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return nil, err
} else if keys != nil && len(keys.Keys) > 0 {
return keys, nil
}

return generateKeySet(ctx, r, m, set, kid, alg)
}

func generateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) {
// Suppress duplicate key set generation jobs where the set+alg match.
keysResult, err, _ := jwkGenFlightGroup.Do(set+alg, func() (any, error) {
r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set)
return m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig")
})
if err != nil {
return nil, err
}
return keysResult.(*jose.JSONWebKeySet), nil
}

func First(keys []jose.JSONWebKey) *jose.JSONWebKey {
Expand Down
23 changes: 12 additions & 11 deletions jwk/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/x/contextx"

"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
)

type fakeSigner struct {
Expand Down Expand Up @@ -226,46 +227,46 @@ func TestGetOrGenerateKeys(t *testing.T) {
return NewMockManager(ctrl)
}

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, ""))
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.NoError(t, err)
assert.Equal(t, privKey, &keySet.Keys[0])
})

t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeySetPrivateKey_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "key not found")
})
Expand Down
3 changes: 2 additions & 1 deletion jwk/jwt_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gofrs/uuid"

"github.com/ory/fosite"

"github.com/ory/hydra/v2/driver/config"

"github.com/pkg/errors"
Expand Down Expand Up @@ -40,7 +41,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st
}

func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) {
private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
if err == nil {
return private, nil
}
Expand Down

0 comments on commit d5f65c5

Please sign in to comment.