diff --git a/go.mod b/go.mod index 2272464cb81..6e100d97d0d 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( golang.org/x/crypto v0.10.0 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/oauth2 v0.9.0 + golang.org/x/sync v0.3.0 golang.org/x/tools v0.10.0 ) @@ -232,7 +233,6 @@ require ( go.opentelemetry.io/proto/otlp v0.18.0 // indirect golang.org/x/mod v0.11.0 // indirect golang.org/x/net v0.11.0 // indirect - golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/jwk/handler.go b/jwk/handler.go index 1442b55f830..89e3f92086c 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -7,6 +7,8 @@ import ( "encoding/json" "net/http" + "golang.org/x/sync/errgroup" + "github.com/ory/x/httprouterx" "github.com/gofrs/uuid" @@ -89,25 +91,34 @@ func (h *Handler) SetRoutes(admin *httprouterx.RouterAdmin, public *httprouterx. // 200: jsonWebKeySet // default: errorOAuth2 func (h *Handler) discoverJsonWebKeys(w http.ResponseWriter, r *http.Request) { - var jwks jose.JSONWebKeySet - - ctx := r.Context() - for _, set := range stringslice.Unique(h.r.Config().WellKnownKeys(ctx)) { - keys, err := h.r.KeyManager().GetKeySet(ctx, set) - if errors.Is(err, x.ErrNotFound) { - h.r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) - keys, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig") - if err != nil { - h.r.Writer().WriteError(w, r, err) - return + eg, ctx := errgroup.WithContext(r.Context()) + wellKnownKeys := stringslice.Unique(h.r.Config().WellKnownKeys(ctx)) + keys := make(chan *jose.JSONWebKeySet, len(wellKnownKeys)) + for _, set := range wellKnownKeys { + set := set + eg.Go(func() error { + k, err := h.r.KeyManager().GetKeySet(ctx, set) + if errors.Is(err, x.ErrNotFound) { + h.r.Logger().Warnf("JSON Web Key Set %q does not exist yet, generating new key pair...", set) + k, err = h.r.KeyManager().GenerateAndPersistKeySet(ctx, set, uuid.Must(uuid.NewV4()).String(), string(jose.RS256), "sig") + if err != nil { + return err + } + } else if err != nil { + return err } - } else if err != nil { - h.r.Writer().WriteError(w, r, err) - return - } - - keys = ExcludePrivateKeys(keys) - jwks.Keys = append(jwks.Keys, keys.Keys...) + keys <- ExcludePrivateKeys(k) + return nil + }) + } + if err := eg.Wait(); err != nil { + h.r.Writer().WriteError(w, r, err) + return + } + close(keys) + var jwks jose.JSONWebKeySet + for k := range keys { + jwks.Keys = append(jwks.Keys, k.Keys...) } h.r.Writer().Write(w, r, &jwks)