diff --git a/.circleci/config.yml b/.circleci/config.yml index 2fb42655fb9..83bc06aee8f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -69,12 +69,49 @@ jobs: - golangci/install - golangci/lint - run: make .bin/go-acc - - run: .bin/go-acc -o coverage.out ./... -- -failfast -timeout=20m -tags sqlite + - run: .bin/go-acc -o coverage.out ./... -- -failfast -timeout=20m -tags=sqlite,hsm # Running race conditions requires parallel tests, otherwise it's worthless (which is the case) # - run: go test -race -short $(go list ./... | grep -v cmd) - run: | bash <(curl -s https://codecov.io/bash) + test-hsm: + docker: + - image: cimg/go:1.16-node + environment: + - HSM_ENABLED=true + - HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so + - HSM_TOKEN_LABEL=hydra + - HSM_PIN=1234 + - TEST_DATABASE_POSTGRESQL=postgres://test:test@localhost:5432/postgres?sslmode=disable + - TEST_DATABASE_MYSQL=mysql://root:test@(localhost:3306)/mysql?multiStatements=true&parseTime=true + - TEST_DATABASE_COCKROACHDB=cockroach://root@localhost:26257/defaultdb?sslmode=disable + - image: postgres:9.6 + environment: + - POSTGRES_USER=test + - POSTGRES_PASSWORD=test + - POSTGRES_DB=postgres + - image: mysql:8.0 + environment: + - MYSQL_ROOT_PASSWORD=test + - image: cockroachdb/cockroach:v20.2.5 + command: start-single-node --insecure + steps: + - checkout + - setup_remote_docker + + - go/load-cache: + key: ory-hydra-go-mod-v1 + - go/mod-download + - go/save-cache: + key: ory-hydra-go-mod-v1 + + - run: sudo apt update + - run: sudo apt install -y softhsm opensc + - run: sudo rm -rf /var/lib/softhsm/tokens; sudo mkdir -p /var/lib/softhsm/tokens; sudo chmod -R a+rwx /var/lib/softhsm; sudo chmod a+rx /etc/softhsm; sudo chmod a+r /etc/softhsm/* + - run: pkcs11-tool --module /usr/lib/softhsm/libsofthsm2.so --slot 0 --init-token --so-pin 0000 --init-pin --pin 1234 --label hydra + - run: go test -p 1 -v -failfast -short -timeout=20m -tags=sqlite,hsm ./... + test-e2e: docker: - image: oryd/e2e-env:latest @@ -158,6 +195,10 @@ workflows: filters: tags: only: /.*/ + - test-hsm: + filters: + tags: + only: /.*/ - test-e2e: filters: tags: @@ -165,6 +206,7 @@ workflows: - changelog/generate: requires: - test + - test-hsm - test-e2e - docs/build # - test-legacy-migrations-mysql @@ -180,6 +222,7 @@ workflows: specignorepgks: internal/httpclient gopkg.in/square/go-jose.v2 requires: - test + - test-hsm - test-e2e # - test-legacy-migrations-mysql # - test-legacy-migrations-cockroach @@ -193,6 +236,7 @@ workflows: swagpath: spec/api.json requires: - test + - test-hsm - sdk/generate - goreleaser/release - docs/build @@ -217,6 +261,7 @@ workflows: - goreleaser/release: requires: - test + - test-hsm - test-e2e - changelog/generate # - test-legacy-migrations-mysql diff --git a/.docker/Dockerfile-hsm b/.docker/Dockerfile-hsm new file mode 100644 index 00000000000..4b46f082cc8 --- /dev/null +++ b/.docker/Dockerfile-hsm @@ -0,0 +1,57 @@ +FROM golang:1.16-alpine AS builder + +RUN apk -U --no-cache add build-base git gcc bash + +WORKDIR /go/src/github.com/ory/hydra + +ADD go.mod go.mod +ADD go.sum go.sum + +ENV GO111MODULE on +ENV CGO_ENABLED 1 + +RUN go mod download + +ADD . . + +FROM builder as build-hydra +RUN go build -tags=sqlite,hsm -o /usr/bin/hydra + +FROM builder as test-hsm +ENV HSM_ENABLED=true +ENV HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so +ENV HSM_TOKEN_LABEL=hydra +ENV HSM_PIN=1234 + +RUN apk -U --no-cache add softhsm opensc; \ + pkcs11-tool --module /usr/lib/softhsm/libsofthsm2.so --slot 0 --init-token --so-pin 0000 --init-pin --pin 1234 --label hydra; \ + go test -p 1 -v -failfast -short -tags=sqlite,hsm ./... + +FROM alpine:3.14.2 + +RUN apk -U --no-cache add softhsm opensc; \ + pkcs11-tool --module /usr/lib/softhsm/libsofthsm2.so --slot 0 --init-token --so-pin 0000 --init-pin --pin 1234 --label hydra + +RUN addgroup -S ory; \ + adduser -S ory -G ory -D -h /home/ory -s /bin/nologin; \ + chown -R ory:ory /home/ory; \ + chown -R ory:ory /var/lib/softhsm/tokens + +COPY --from=build-hydra /usr/bin/hydra /usr/bin/hydra + +# By creating the sqlite folder as the ory user, the mounted volume will be owned by ory:ory, which +# is required for read/write of SQLite. +RUN mkdir -p /var/lib/sqlite +RUN chown ory:ory /var/lib/sqlite +VOLUME /var/lib/sqlite + +# Exposing the ory home directory +VOLUME /home/ory + +# Declare the standard ports used by hydra (4433 for public service endpoint, 4434 for admin service endpoint) +EXPOSE 4444 4445 + +USER ory + +ENTRYPOINT ["hydra"] +CMD ["serve"] diff --git a/.dockerignore b/.dockerignore index c437c041195..4d913fbbc91 100644 --- a/.dockerignore +++ b/.dockerignore @@ -13,5 +13,4 @@ dist .bin test/e2e test/mock-* -test/stub cypress diff --git a/.goreleaser.yml b/.goreleaser.yml index 3f890585cbe..490f61550f7 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -14,8 +14,7 @@ builds: - id: hydra-sqlite-darwin flags: - - -tags - - sqlite + - -tags=sqlite,hsm ldflags: - -s -w -X github.com/ory/hydra/driver/config.Version={{.Tag}} -X github.com/ory/hydra/driver/config.Commit={{.FullCommit}} -X github.com/ory/hydra/driver/config.Date={{.Date}} # - "-extldflags '-static'" @@ -32,8 +31,7 @@ builds: - id: hydra-sqlite-linux flags: - - -tags - - sqlite + - -tags=sqlite,hsm ldflags: - -s -w -X github.com/ory/hydra/driver/config.Version={{.Tag}} -X github.com/ory/hydra/driver/config.Commit={{.FullCommit}} -X github.com/ory/hydra/driver/config.Date={{.Date}} binary: hydra @@ -46,8 +44,7 @@ builds: - id: hydra-sqlite-linux-libmusl flags: - - -tags - - sqlite + - -tags=sqlite,hsm ldflags: - -s -w -X github.com/ory/hydra/driver/config.Version={{.Tag}} -X github.com/ory/hydra/driver/config.Commit={{.FullCommit}} -X github.com/ory/hydra/driver/config.Date={{.Date}} binary: hydra @@ -61,8 +58,7 @@ builds: - id: hydra-sqlite-windows flags: - - -tags - - sqlite + - -tags=sqlite,hsm ldflags: - -s -w -X github.com/ory/hydra/driver/config.Version={{.Tag}} -X github.com/ory/hydra/driver/config.Commit={{.FullCommit}} -X github.com/ory/hydra/driver/config.Date={{.Date}} - "-extldflags '-static'" diff --git a/Makefile b/Makefile index 406a8b26e19..41ae5a42310 100644 --- a/Makefile +++ b/Makefile @@ -81,6 +81,10 @@ e2e: node_modules test-resetdb quicktest: go test -failfast -short -tags sqlite ./... +.PHONY: quicktest-hsm +quicktest-hsm: + docker build --progress=plain -f .docker/Dockerfile-hsm --target test-hsm . + # Formats the code .PHONY: format format: .bin/goimports node_modules docs/node_modules contributors diff --git a/cmd/cli/handler_import_jwk_test.go b/cmd/cli/handler_import_jwk_test.go index 7c535af602c..cf35f31b4b7 100644 --- a/cmd/cli/handler_import_jwk_test.go +++ b/cmd/cli/handler_import_jwk_test.go @@ -23,6 +23,10 @@ func TestImportJSONWebKey(t *testing.T) { reg := internal.NewRegistryMemory(t, conf) router := x.NewRouterPublic() + if conf.HsmEnabled() { + t.Skip("Skipping test. Keys cannot be imported when Hardware Security Module is enabled") + } + h := reg.KeyHandler() m := reg.KeyManager() diff --git a/cmd/root_test.go b/cmd/root_test.go index 772a933f1bc..3eeb5b13076 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -30,6 +30,8 @@ import ( "testing" "time" + "github.com/ory/hydra/internal" + "github.com/phayes/freeport" "github.com/stretchr/testify/assert" @@ -73,6 +75,7 @@ func init() { func TestExecute(t *testing.T) { frontend := fmt.Sprintf("https://localhost:%d/", frontendPort) backend := fmt.Sprintf("https://localhost:%d/", backendPort) + conf := internal.NewConfigurationWithDefaults() rootCmd := NewRootCmd() @@ -80,6 +83,7 @@ func TestExecute(t *testing.T) { args []string wait func() bool expectErr bool + skipTest bool }{ { args: []string{"serve", "all", "--sqa-opt-out"}, @@ -116,14 +120,15 @@ func TestExecute(t *testing.T) { {args: []string{"clients", "create", "--skip-tls-verify", "--endpoint", backend, "--id", "public-foo"}}, {args: []string{"clients", "create", "--skip-tls-verify", "--endpoint", backend, "--id", "confidential-foo", "--pgp-key", base64EncodedPGPPublicKey(t), "--grant-types", "client_credentials", "--response-types", "token"}}, {args: []string{"clients", "delete", "--skip-tls-verify", "--endpoint", backend, "public-foo"}}, - {args: []string{"keys", "create", "--skip-tls-verify", "foo", "--endpoint", backend, "-a", "HS256"}}, + {args: []string{"keys", "create", "--skip-tls-verify", "foo", "--endpoint", backend, "-a", "RS256"}}, + {args: []string{"keys", "create", "--skip-tls-verify", "foo", "--endpoint", backend, "-a", "HS256"}, skipTest: conf.HsmEnabled()}, {args: []string{"keys", "get", "--skip-tls-verify", "--endpoint", backend, "foo"}}, // {args: []string{"keys", "rotate", "--skip-tls-verify", "--endpoint", backend, "foo"}}, {args: []string{"keys", "get", "--skip-tls-verify", "--endpoint", backend, "foo"}}, {args: []string{"keys", "delete", "--skip-tls-verify", "--endpoint", backend, "foo"}}, - {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-1", "../test/stub/ecdh.key", "../test/stub/ecdh.pub"}}, - {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-2", "../test/stub/rsa.key", "../test/stub/rsa.pub"}}, - {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-2", "../test/stub/rsa.key", "../test/stub/rsa.pub"}}, + {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-1", "../test/stub/ecdh.key", "../test/stub/ecdh.pub"}, skipTest: conf.HsmEnabled()}, + {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-2", "../test/stub/rsa.key", "../test/stub/rsa.pub"}, skipTest: conf.HsmEnabled()}, + {args: []string{"keys", "import", "--skip-tls-verify", "--endpoint", backend, "import-2", "../test/stub/rsa.key", "../test/stub/rsa.pub"}, skipTest: conf.HsmEnabled()}, {args: []string{"token", "revoke", "--skip-tls-verify", "--endpoint", frontend, "--client-secret", "foobar", "--client-id", "foobarbaz", "foo"}}, {args: []string{"token", "client", "--skip-tls-verify", "--endpoint", frontend, "--client-secret", "foobar", "--client-id", "foobarbaz"}}, {args: []string{"help", "migrate", "sql"}}, @@ -133,6 +138,10 @@ func TestExecute(t *testing.T) { rootCmd.SetArgs(c.args) t.Run(fmt.Sprintf("command=%v", c.args), func(t *testing.T) { + if c.skipTest { + t.Skip("Skipping test. Not applicable when Hardware Security Module is enabled") + } + if c.wait != nil { go func() { assert.Nil(t, rootCmd.Execute()) diff --git a/cmd/server/helper_cert.go b/cmd/server/helper_cert.go index 6daa240f6c9..992a1c51852 100644 --- a/cmd/server/helper_cert.go +++ b/cmd/server/helper_cert.go @@ -42,7 +42,7 @@ import ( ) const ( - tlsKeyName = "hydra.https-tls" + TlsKeyName = "hydra.https-tls" ) func AttachCertificate(priv *jose.JSONWebKey, cert *x509.Certificate) { @@ -63,9 +63,9 @@ func GetOrCreateTLSCertificate(cmd *cobra.Command, d driver.Registry, iface conf d.Logger().WithError(err).Fatalf("Unable to load HTTPS TLS Certificate") } - _, priv, err := jwk.AsymmetricKeypair(context.Background(), d, &jwk.RS256Generator{KeyLength: 4069}, tlsKeyName) + _, priv, err := jwk.GetOrGenerateKeys(context.Background(), d, d.SoftwareKeyManager(), TlsKeyName, TlsKeyName, "RS256") if err != nil { - d.Logger().WithError(err).Fatal("Unable to fetch HTTPS TLS key pairs") + d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair") } if len(priv.Certificates) == 0 { @@ -75,11 +75,11 @@ func GetOrCreateTLSCertificate(cmd *cobra.Command, d driver.Registry, iface conf } AttachCertificate(priv, cert) - if err := d.KeyManager().DeleteKey(context.TODO(), tlsKeyName, priv.KeyID); err != nil { + if err := d.SoftwareKeyManager().DeleteKey(context.TODO(), TlsKeyName, priv.KeyID); err != nil { d.Logger().WithError(err).Fatal(`Could not update (delete) the self signed TLS certificate`) } - if err := d.KeyManager().AddKey(context.TODO(), tlsKeyName, priv); err != nil { + if err := d.SoftwareKeyManager().AddKey(context.TODO(), TlsKeyName, priv); err != nil { d.Logger().WithError(err).Fatalf(`Could not update (add) the self signed TLS certificate: %s %x %d`, cert.SignatureAlgorithm, cert.Signature, len(cert.Signature)) } } diff --git a/docs/config.js b/docs/config.js index 3cd4fd501e9..66e8876f9f0 100644 --- a/docs/config.js +++ b/docs/config.js @@ -1,5 +1,5 @@ module.exports = { - projectName: 'ORY Hydra', + projectName: 'Ory Hydra', projectSlug: 'hydra', newsletter: 'https://ory.us10.list-manage.com/subscribe?u=ffb1a878e4ec6c0ed312a3480&id=f605a41b53&group[17097][8]=1', diff --git a/docs/docs/.static/api.json b/docs/docs/.static/api.json index 28f8483f1c4..0f9288f2a4d 100755 --- a/docs/docs/.static/api.json +++ b/docs/docs/.static/api.json @@ -526,7 +526,7 @@ } }, "put": { - "description": "Use this method if you do not want to let Hydra generate the JWKs for you, but instead save your own.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well.", + "description": "Use this method if you do not want to let Hydra generate the JWKs for you, but instead save your own.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well. This method is not supported when Hardware Security Module is enabled.", "consumes": ["application/json"], "produces": ["application/json"], "schemes": ["http", "https"], @@ -577,7 +577,7 @@ } }, "post": { - "description": "This endpoint is capable of generating JSON Web Key Sets for you. There a different strategies available, such as symmetric cryptographic keys (HS256, HS512) and asymetric cryptographic keys (RS256, ECDSA). If the specified JSON Web Key Set does not exist, it will be created.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well.", + "description": "This endpoint is capable of generating JSON Web Key Sets for you. There a different strategies available, such as symmetric cryptographic keys (HS256, HS512) and asymetric cryptographic keys (RS256, ECDSA, EdDSA). When Hardware Security Module is enabled, then only RS256, ECDSA key strategies are available. If the specified JSON Web Key Set does not exist, it will be created.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well.", "consumes": ["application/json"], "produces": ["application/json"], "schemes": ["http", "https"], @@ -716,7 +716,7 @@ } }, "put": { - "description": "Use this method if you do not want to let Hydra generate the JWKs for you, but instead save your own.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well.", + "description": "Use this method if you do not want to let Hydra generate the JWKs for you, but instead save your own.\n\nA JSON Web Key (JWK) is a JavaScript Object Notation (JSON) data structure that represents a cryptographic key. A JWK Set is a JSON data structure that represents a set of JWKs. A JSON Web Key is identified by its set and key id. ORY Hydra uses this functionality to store cryptographic keys used for TLS and JSON Web Tokens (such as OpenID Connect ID tokens), and allows storing user-defined keys as well. This method is not supported when Hardware Security Module is enabled.", "consumes": ["application/json"], "produces": ["application/json"], "schemes": ["http", "https"], diff --git a/docs/docs/5min-tutorial.mdx b/docs/docs/5min-tutorial.mdx index 3b3fa78e038..fe4776bef32 100644 --- a/docs/docs/5min-tutorial.mdx +++ b/docs/docs/5min-tutorial.mdx @@ -77,6 +77,15 @@ $ docker-compose -f quickstart.yml \ up --build ``` +If you want to test Hardware Security Module add `-f quickstart-hsm.yml`. For +more information head over to [HSM support](hsm-support). + +```shell script +$ docker-compose -f quickstart.yml \ + -f quickstart-hsm.yml \ + up --build +``` + Let's confirm that everything is working by creating an OAuth 2.0 Client. Note: The following commands run Hydra inside Docker. If you have the ORY Hydra @@ -186,9 +195,8 @@ $ docker-compose -f quickstart.yml rm -f -v ### Quickstart Configuration -In this tutorial we use a simplified configuration. -You can find it in -[`contrib/quickstart/5-min/hydra.yml`](https://github.com/ory/hydra/blob/master/contrib/quickstart/5-min/hydra.yml). +In this tutorial we use a simplified configuration. You can find it in +[`contrib/quickstart/5-min/hydra.yml`](https://github.com/ory/hydra/blob/master/contrib/quickstart/5-min/hydra.yml). The configuration gets loaded in docker-compose as specified in the [`quickstart.yml`](https://github.com/ory/hydra/blob/master/quickstart.yml). diff --git a/docs/docs/guides/hsm-support.md b/docs/docs/guides/hsm-support.md new file mode 100644 index 00000000000..4d71ff8baa4 --- /dev/null +++ b/docs/docs/guides/hsm-support.md @@ -0,0 +1,224 @@ +--- +id: hsm-support +title: Hardware Security Module support for JSON Web Key Sets +--- + +The +[PKCS#11 Cryptographic Token Interface Standard](http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/os/pkcs11-base-v2.40-os.html), +also known as Cryptoki, is one of the Public Key Cryptography Standards +developed by RSA Security. PKCS#11 defines the interface between an application +and a cryptographic device. + +:::note + +If a key is not found in the Hardware Security Module, the regular Software Key +Manager with AES-GCM software encryption will be used as a fallback. Storing +keys will always use the Software Key Manager as it is not possible to add keys +to a Hardware Security Module. + +::: + +PKCS#11 is used as a low-level interface to perform cryptographic operations +without the need for the application to directly interface a device through its +driver. PKCS#11 represents cryptographic devices using a common model referred +to simply as a token. An application can therefore perform cryptographic +operations on any device or token, using the same independent command set. + + + +### Hardware Security Module configuration + +Ory Hydra can be configured using environment variables as well as a +configuration file. For more information on configuration options, open the +configuration documentation: + +>> https://www.ory.sh/hydra/docs/reference/configuration << + +``` +HSM_ENABLED=true +HSM_LIBRARY=/path/to/hsm-vendor/library.so +HSM_TOKEN_LABEL=hydra +HSM_SLOT=0 +HSM_PIN=1234 +``` + +Token that is denoted by environment variables `HSM_TOKEN_LABEL` or `HSM_SLOT` +must preexist and optionally contain RSA or ECDSA key pairs with labels +`hydra.openid.id-token` and `hydra.jwt.access-token` depending on configuration. +**_If keys with these labels don't exist, they will be generated upon +startup._** If both `HSM_TOKEN_LABEL` and `HSM_SLOT` are set, `HSM_TOKEN_LABEL` +takes preference over `HSM_SLOT`. In this case first slot that contains this +label is used. `HSM_LIBRARY` must point to vendor specific PKCS#11 library or +SoftHSM library if you want to [test HSM support](#testing-with-softhsm). + + + +### PKCS#11 attribute mappings to JSON Web Key Set attributes + +When key pair is generated or requested from HSM, the `CKA_LABEL` attribute is +used as JSON Web Key Set name, `CKA_ID` attribute as `kid`. Key usage is +determined by private key attributes, where `CKA_SIGN` and `CKA_DECRYPT` are +mapped to `sig` and `enc` respectively and set as key `use` attribute. +Furthermore, `CKA_ID's` of key pair private/public handles must be identical. +Attribute `alg` is determined from `CKA_KEY_TYPE` and `CKA_ECDSA_PARAMS`. + + + +### Supported key algorithms + +Ory Hydra supports generating 4096 bit RSA, ECDSA keys with curves secp256r1 or +secp521r1. As of now PKCS#11 v2.4 doesn't support EdDSA keys using curve +Ed25519. However, +[PKCS#11 v3.0](https://docs.oasis-open.org/pkcs11/pkcs11-curr/v3.0/pkcs11-curr-v3.0.html) +contains support for EdDSA and therefore can be supported in upcoming versions. +Symmetric key algorithms are not supported because it would imply, that shared +HSM is used between server and authenticating client. + + + +### Generating key pairs + + + +#### Initializing token + +Different policies can apply for tokens, therefore HSM configuration expects, +that token where to find or generate keys already exists. Depending on HSM +vendor, tools initializing tokens and generating keys vary. To demonstrate key +pair generation we first initialize token using `pkcs11-tool` (see how to +[setup SoftHSM and OpenSC](#testing-with-softhsm)) + +```shell +$ pkcs11-tool --module /usr/lib/softhsm/libsofthsm2.so --slot 0 --init-token --so-pin 0000 --pin 1234 --init-pin --label hydra + +Using slot 0 with a present token (0x2763db07) +Token successfully initialized +User PIN successfully initialized +``` + +Corresponding Ory Hydra configuration to access this token would be + +``` +HSM_ENABLED=true +HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so +HSM_TOKEN_LABEL=hydra +HSM_SLOT=0 +HSM_PIN=1234 +``` + + + +#### Generating key pair + +Generating RSA keypair for JSON Web Key `hydra.openid.id-token` + +```shell +$ pkcs11-tool --module /usr/lib/softhsm/libsofthsm2.so \ +--pin 1234 --token-label hydra \ +--keypairgen --key-type rsa:4096 --usage-sign \ +--label hydra.openid.id-token --id 746573742d6b65792d6964 + +Key pair generated: +Private Key Object; RSA + label: hydra.openid.id-token + ID: 746573742d6b65792d6964 + Usage: sign +Public Key Object; RSA 4096 bits + label: hydra.openid.id-token + ID: 746573742d6b65792d6964 + Usage: verify +``` + +| Parameter | Description | +| :---------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--module` | Points to vendor specific PKCS#11 library or SoftHSM library when testing. | +| `--pin 1234` | Pin that was used in token initialization to perform token operations. | +| `--token-label hydra` | Performs key generation on first slot with label `hydra`. Use `--slot` option instead if you want to specify specific slot. | +| `--label hydra.openid.id-token` | Sets key pair label attribute `CKA_LABEL` and is used as JSON Web Key Set name. | +| `--id 746573742d6b65792d6964` | Sets key pair id attribute `CKA_ID` and is used as JSON Web Key Set `kid`. It must be set as a big-endian hexadecimal integer value. `StringToHex("test-key-id") == 746573742d6b65792d6964` | +| `--keypairgen` | Perform key pair generation on token | +| `--key-type rsa:4096` | Type and length of the key to generate. Supported values are `rsa:4096`, `EC:secp256r1` or `EC:secp521r1`. Sets `CKA_KEY_TYPE`,`CKA_ECDSA_PARAMS` attributes and is used to determine JSON Web Key Set `alg` attribute. | +| `--usage-sign` or `--usage-decrypt` | Sets private key attribute `CKA_SIGN` or `CKA_DECRYPT` respectively. Used to determine JSON Web Key Set `use` attribute. | + + + +##### Key type mappings + +| Key type | JWT signing algorithm | +| :----------- | :-------------------- | +| rsa:4096 | RS256 | +| EC:secp256r1 | ES256 | +| EC:secp521r1 | ES512 | + + + +### Testing with SoftHSM + +[SoftHSM](https://www.opendnssec.org/softhsm/) is an implementation of a +cryptographic store accessible through a PKCS #11 interface. You can use it to +explore PKCS#11 without having a Hardware Security Module. It is being developed +as a part of the OpenDNSSEC project. + +[Follow these instructions to build SoftHSM from source.](https://wiki.opendnssec.org/display/SoftHSMDOCS/SoftHSM+Documentation+v2) + +#### Install SoftHSM/OpenSC on Mac OSX + +```shell +$ ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 2> /dev/null +``` + +```shell +$ brew install softhsm +``` + +```shell +$ ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 2> /dev/null +``` + +```shell +$ brew install opensc +``` + +#### Install SoftHSM/OpenSC on Ubuntu + +```shell +$ sudo apt update +``` + +```shell +$ sudo apt install softhsm opensc +``` + +#### Install SoftHSM/OpenSC on Windows + +Follow these instructions to install +[SoftHSM](https://github.com/disig/SoftHSM2-for-Windows) and +[OpenSC](https://github.com/OpenSC/OpenSC/wiki) on windows. + +#### Run Ory Hydra with HSM using Docker + +Alternatively you can use quickstart docker container that setups +SoftHSM/OpenSC, builds and runs Ory Hydra with HSM configuration enabled. You +need to have the latest [Docker](https://www.docker.com) and +[Docker Compose](https://docs.docker.com/compose) version installed. To run +quickstart HSM change into the directory with the Hydra source code and run the +following command: + +```shell +$ docker-compose -f quickstart-hsm.yml up --build +``` + +Following is logged on startup if Hardware Security Module is successfully +configured: + +```shell +$ docker logs ory-hydra-example--hydra +time="2021-07-07T12:51:23Z" level=info msg="Hardware Security Module is configured." +time="2021-07-07T12:51:23Z" level=info msg="JSON Web Key Set 'hydra.openid.id-token' does not exist yet, generating new key pair..." +``` + +#### Run Tests with HSM enabled using Docker + +```shell +$ make quicktest-hsm +``` diff --git a/docs/docs/index.md b/docs/docs/index.md index 291111ff66c..cecf4dfbb63 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -38,6 +38,11 @@ In addition to the OAuth 2.0 functionality, ORY Hydra offers a safe storage for cryptographic keys (used for example to sign JSON Web Tokens) and can manage OAuth 2.0 Clients. +### Hardware Security Module support + +ORY Hydra also offers a safe storage for cryptographic keys using HSM. +[Learn more](guides/hsm-support.md). + ## Security First ORY Hydra's architecture and work flows are designed to neutralize many common diff --git a/docs/docs/reference/configuration.md b/docs/docs/reference/configuration.md index f788f8bf552..d51a3fe06ea 100644 --- a/docs/docs/reference/configuration.md +++ b/docs/docs/reference/configuration.md @@ -714,6 +714,31 @@ serve: # dsn: '' +## hsm ## +# Configures Hardware Security Module for hydra.openid.id-token, hydra.jwt.access-token keys +# Either slot or token_label must be set. If token_label is set, then first slot in index with this label is used. +# +# Set this value using environment variables on +# - Linux/macOS: +# $ export HSM_ENABLED= +# $ export HSM_LIBRARY= +# $ export HSM_PIN= +# $ export HSM_SLOT= +# $ export HSM_TOKEN_LABEL= +# - Windows Command Line (CMD): +# > set HSM_ENABLED= +# > set HSM_LIBRARY= +# > set HSM_PIN= +# > set HSM_SLOT= +# > set HSM_TOKEN_LABEL= +# +hsm: + enabled: false + library: /path/to/hsm-vendor/library.so + pin: partition-pin-code + slot: 0 + token_label: hydra + ## webfinger ## # # Configures ./well-known/ settings. diff --git a/docs/sidebar.json b/docs/sidebar.json index a1f950dc1dd..65e91c4c407 100644 --- a/docs/sidebar.json +++ b/docs/sidebar.json @@ -29,6 +29,7 @@ "dependencies-environment", "production", "guides/tracing", + "guides/hsm-support", "guides/secrets-key-rotation", "guides/kubernetes-helm-chart", "guides/ssl-https-tls", diff --git a/driver/config/provider.go b/driver/config/provider.go index c2c65f66e77..f816da2fe60 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -24,6 +24,11 @@ import ( const ( KeyRoot = "" + HsmEnabled = "hsm.enabled" + HsmLibraryPath = "hsm.library" + HsmPin = "hsm.pin" + HsmSlotNumber = "hsm.slot" + HsmTokenLabel = "hsm.token_label" // #nosec G101 KeyWellKnownKeys = "webfinger.jwks.broadcast_keys" KeyOAuth2ClientRegistrationURL = "webfinger.oidc_discovery.client_registration_url" KeyOAuth2TokenURL = "webfinger.oidc_discovery.token_url" // #nosec G101 @@ -442,6 +447,27 @@ func (p *Provider) GrantAllClientCredentialsScopesPerDefault() bool { return p.p.Bool(KeyGrantAllClientCredentialsScopesPerDefault) } +func (p *Provider) HsmEnabled() bool { + return p.p.Bool(HsmEnabled) +} + +func (p *Provider) HsmLibraryPath() string { + return p.p.String(HsmLibraryPath) +} + +func (p *Provider) HsmSlotNumber() *int { + n := p.p.Int(HsmSlotNumber) + return &n +} + +func (p *Provider) HsmPin() string { + return p.p.String(HsmPin) +} + +func (p *Provider) HsmTokenLabel() string { + return p.p.String(HsmTokenLabel) +} + func (p *Provider) GrantTypeJWTBearerIDOptional() bool { return p.p.Bool(KeyOAuth2GrantJWTIDOptional) } diff --git a/driver/registry.go b/driver/registry.go index 1309e00b59c..4bea95f01b4 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -3,6 +3,8 @@ package driver import ( "context" + "github.com/ory/hydra/hsm" + "github.com/ory/hydra/oauth2/trust" "github.com/pkg/errors" @@ -37,6 +39,7 @@ type Registry interface { WithBuildInfo(v, h, d string) Registry WithConfig(c *config.Provider) Registry WithLogger(l *logrusx.Logger) Registry + WithKeyGenerators(kg map[string]jwk.KeyGenerator) Registry Config() *config.Provider persistence.Provider @@ -61,6 +64,7 @@ type Registry interface { OAuth2HMACStrategy() *foauth2.HMACSHAStrategy WithOAuth2Provider(f fosite.OAuth2Provider) WithConsentStrategy(c consent.Strategy) + WithHsmContext(h hsm.Context) } func NewRegistryFromDSN(ctx context.Context, c *config.Provider, l *logrusx.Logger) (Registry, error) { diff --git a/driver/registry_base.go b/driver/registry_base.go index a593121c425..93f29659094 100644 --- a/driver/registry_base.go +++ b/driver/registry_base.go @@ -8,6 +8,8 @@ import ( "strings" "time" + "github.com/ory/hydra/hsm" + prometheus "github.com/ory/x/prometheusx" "github.com/pkg/errors" @@ -61,6 +63,7 @@ type RegistryBase struct { fsc fosite.ScopeStrategy atjs jwk.JWTStrategy idtjs jwk.JWTStrategy + hsm hsm.Context fscPrev string fos *openid.DefaultStrategy forv *openid.OpenIDConnectRequestValidator @@ -131,6 +134,11 @@ func (m *RegistryBase) WithConfig(c *config.Provider) Registry { return m.r } +func (m *RegistryBase) WithKeyGenerators(kg map[string]jwk.KeyGenerator) Registry { + m.kg = kg + return m.r +} + func (m *RegistryBase) Writer() herodot.Writer { if m.writer == nil { h := herodot.NewJSONWriter(m.Logger()) @@ -370,7 +378,8 @@ func (m *RegistryBase) ScopeStrategy() fosite.ScopeStrategy { } func (m *RegistryBase) newKeyStrategy(key string) (s jwk.JWTStrategy) { - if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), m.r, new(jwk.RS256Generator), key); err != nil { + + if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), m.r, "RS256", key); err != nil { var netError net.Error if errors.As(err, &netError) { m.Logger().WithError(err).Fatalf(`Could not ensure that signing keys for "%s" exists. A network error occurred, see error for specific details.`, key) @@ -381,7 +390,7 @@ func (m *RegistryBase) newKeyStrategy(key string) (s jwk.JWTStrategy) { } if err := resilience.Retry(m.Logger(), time.Second*15, time.Minute*15, func() (err error) { - s, err = jwk.NewRS256JWTStrategy(m.r, func() string { + s, err = jwk.NewRS256JWTStrategy(*m.C, m.r, func() string { return key }) return err @@ -502,6 +511,17 @@ func (m *RegistryBase) AccessRequestHooks() []oauth2.AccessRequestHook { return m.arhs } +func (m *RegistryBase) WithHsmContext(h hsm.Context) { + m.hsm = h +} + +func (m *RegistryBase) HsmContext() hsm.Context { + if m.hsm == nil { + m.hsm = hsm.NewContext(m.C, m.l) + } + return m.hsm +} + func (m *RegistrySQL) ClientAuthenticator() x.ClientAuthenticator { return m.OAuth2Provider().(*fosite.Fosite) } diff --git a/driver/registry_base_test.go b/driver/registry_base_test.go index 3e1fb19b679..2fbc83e805e 100644 --- a/driver/registry_base_test.go +++ b/driver/registry_base_test.go @@ -29,6 +29,7 @@ func TestRegistryBase_newKeyStrategy_handlesNetworkError(t *testing.T) { // Create a config and set a valid but unresolvable DSN c := config.MustNew(l, configx.WithConfigFiles("../internal/.hydra.yaml")) c.MustSet(config.KeyDSN, "postgres://user:password@127.0.0.1:9999/postgres") + c.MustSet(config.HsmEnabled, "false") registry, err := NewRegistryFromDSN(context.Background(), c, l) if err != nil { @@ -37,6 +38,7 @@ func TestRegistryBase_newKeyStrategy_handlesNetworkError(t *testing.T) { } registryBase := RegistryBase{r: registry, l: l} + registryBase.WithConfig(c) strategy := registryBase.newKeyStrategy("key") diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 1bb429ec048..96583ba5008 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -5,6 +5,8 @@ import ( "strings" "time" + "github.com/ory/hydra/hsm" + "github.com/gobuffalo/pop/v6" "github.com/ory/hydra/oauth2/trust" @@ -32,7 +34,8 @@ import ( type RegistrySQL struct { *RegistryBase - db *sqlx.DB + db *sqlx.DB + defaultKeyManager jwk.Manager } var _ Registry = new(RegistrySQL) @@ -82,6 +85,13 @@ func (m *RegistrySQL) Init(ctx context.Context) error { return err } + if m.C.HsmEnabled() { + hardwareKeyManager := hsm.NewKeyManager(m.HsmContext()) + m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister) + } else { + m.defaultKeyManager = m.persister + } + // if dsn is memory we have to run the migrations on every start // use case - such as // - just in memory @@ -123,6 +133,10 @@ func (m *RegistrySQL) OAuth2Storage() x.FositeStorer { } func (m *RegistrySQL) KeyManager() jwk.Manager { + return m.defaultKeyManager +} + +func (m *RegistrySQL) SoftwareKeyManager() jwk.Manager { return m.Persister() } diff --git a/driver/registry_sql_test.go b/driver/registry_sql_test.go new file mode 100644 index 00000000000..cddaab34239 --- /dev/null +++ b/driver/registry_sql_test.go @@ -0,0 +1,24 @@ +package driver + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ory/hydra/driver/config" + "github.com/ory/hydra/persistence/sql" + "github.com/ory/x/configx" + "github.com/ory/x/logrusx" +) + +func TestDefaultKeyManager_HsmDisabled(t *testing.T) { + l := logrusx.New("", "") + c := config.MustNew(l, configx.SkipValidation()) + c.MustSet(config.KeyDSN, "postgres://user:password@127.0.0.1:9999/postgres") + c.MustSet(config.HsmEnabled, "false") + reg, err := NewRegistryFromDSN(context.Background(), c, l) + assert.NoError(t, err) + assert.IsType(t, &sql.Persister{}, reg.KeyManager()) + assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) +} diff --git a/go.mod b/go.mod index 649e4f57127..62a2962cfba 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ replace ( ) require ( + github.com/ThalesIgnite/crypto11 v1.2.4 github.com/cenkalti/backoff/v3 v3.0.0 github.com/evanphx/json-patch v4.9.0+incompatible github.com/go-bindata/go-bindata v3.1.2+incompatible @@ -38,11 +39,12 @@ require ( github.com/julienschmidt/httprouter v1.3.0 github.com/luna-duclos/instrumentedsql v1.1.3 github.com/luna-duclos/instrumentedsql/opentracing v0.0.0-20201103091713-40d03108b6f4 + github.com/miekg/pkcs11 v1.0.3 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/oleiade/reflections v1.0.1 github.com/olekukonko/tablewriter v0.0.1 github.com/ory/analytics-go/v4 v4.0.2 - github.com/ory/fosite v0.40.3-0.20210927193520-47901ddecc68 + github.com/ory/fosite v0.40.3-0.20211013150831-5027277a8297 github.com/ory/go-acc v0.2.6 github.com/ory/graceful v0.1.1 github.com/ory/herodot v0.9.12 diff --git a/go.sum b/go.sum index e4c53e77243..373a0099bb9 100644 --- a/go.sum +++ b/go.sum @@ -105,6 +105,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/Shopify/logrus-bugsnag v0.0.0-20171204204709-577dee27f20d/go.mod h1:HI8ITrYtUY+O+ZhtlqUnD8+KwNPOyugEhfP9fdUIaEQ= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/ThalesIgnite/crypto11 v1.2.4 h1:3MebRK/U0mA2SmSthXAIZAdUA9w8+ZuKem2O6HuR1f8= +github.com/ThalesIgnite/crypto11 v1.2.4/go.mod h1:ILDKtnCKiQ7zRoNxcp36Y1ZR8LBPmR2E23+wTQe/MlE= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3 h1:/SkiUr3JJzun9QN9cpUVCPri2ZwOFJ3ani+F3vdoCiY= github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= @@ -1187,6 +1189,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le github.com/microcosm-cc/bluemonday v1.0.16 h1:kHmAq2t7WPWLjiGvzKa5o3HzSfahUKiOq7fAPUiMNIc= github.com/microcosm-cc/bluemonday v1.0.16/go.mod h1:Z0r70sCuXHig8YpBzCc5eGHAap2K7e/u082ZUpDRRqM= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/miekg/pkcs11 v1.0.3 h1:iMwmD7I5225wv84WxIG/bmxz9AXjWvTWIbM/TYHvWtw= github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/mistifyio/go-zfs v2.1.2-0.20190413222219-f784269be439+incompatible/go.mod h1:8AuVvqP/mXw1px98n46wfvcGfQ4ci2FwoAjKYxuo3Z4= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -1335,8 +1339,8 @@ github.com/ory/dockertest/v3 v3.6.5/go.mod h1:iYKQSRlYrt/2s5fJWYdB98kCQG6g/LjBMv github.com/ory/dockertest/v3 v3.8.1 h1:vU/8d1We4qIad2YM0kOwRVtnyue7ExvacPiw1yDm17g= github.com/ory/dockertest/v3 v3.8.1/go.mod h1:wSRQ3wmkz+uSARYMk7kVJFDBGm8x5gSxIhI7NDc+BAQ= github.com/ory/fosite v0.29.0/go.mod h1:0atSZmXO7CAcs6NPMI/Qtot8tmZYj04Nddoold4S2h0= -github.com/ory/fosite v0.40.3-0.20210927193520-47901ddecc68 h1:jTHWt0Yh4UmIB4lNLbvM1Yb4/RRYA9lvKTJ2J830Wdo= -github.com/ory/fosite v0.40.3-0.20210927193520-47901ddecc68/go.mod h1:IIRYBnuhyfgmYpSwk1h56+2CI7p+605KRCiJ7olUcl0= +github.com/ory/fosite v0.40.3-0.20211013150831-5027277a8297 h1:r8t/5GYtFx8dY+OuebrxbmCh+sL9B9KW1gc4xCy9hCE= +github.com/ory/fosite v0.40.3-0.20211013150831-5027277a8297/go.mod h1:IIRYBnuhyfgmYpSwk1h56+2CI7p+605KRCiJ7olUcl0= github.com/ory/go-acc v0.0.0-20181118080137-ddc355013f90/go.mod h1:sxnvPCxChFuSmTJGj8FdMupeq1BezCiEpDjTUXQ4hf4= github.com/ory/go-acc v0.2.6 h1:YfI+L9dxI7QCtWn2RbawqO0vXhiThdXu/RgizJBbaq0= github.com/ory/go-acc v0.2.6/go.mod h1:4Kb/UnPcT8qRAk3IAxta+hvVapdxTLWtrr7bFLlEgpw= @@ -1620,6 +1624,8 @@ github.com/syndtr/gocapability v0.0.0-20170704070218-db04d3cc01c8/go.mod h1:hkRG github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/tchap/go-patricia v2.2.6+incompatible/go.mod h1:bmLyhP68RS6kStMGxByiQ23RP/odRBOTVjwp2cDyi6I= +github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gtvVDbmPg= +github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU= github.com/tidwall/gjson v1.3.2/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.6.8/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI= github.com/tidwall/gjson v1.7.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= diff --git a/hsm/crypto11_mock_test.go b/hsm/crypto11_mock_test.go new file mode 100644 index 00000000000..d1828c51bba --- /dev/null +++ b/hsm/crypto11_mock_test.go @@ -0,0 +1,97 @@ +//go:build hsm +// +build hsm + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ThalesIgnite/crypto11 (interfaces: SignerDecrypter) + +// Package hsm_test is a generated GoMock package. +package hsm_test + +import ( + crypto "crypto" + io "io" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockSignerDecrypter is a mock of SignerDecrypter interface. +type MockSignerDecrypter struct { + ctrl *gomock.Controller + recorder *MockSignerDecrypterMockRecorder +} + +// MockSignerDecrypterMockRecorder is the mock recorder for MockSignerDecrypter. +type MockSignerDecrypterMockRecorder struct { + mock *MockSignerDecrypter +} + +// NewMockSignerDecrypter creates a new mock instance. +func NewMockSignerDecrypter(ctrl *gomock.Controller) *MockSignerDecrypter { + mock := &MockSignerDecrypter{ctrl: ctrl} + mock.recorder = &MockSignerDecrypterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSignerDecrypter) EXPECT() *MockSignerDecrypterMockRecorder { + return m.recorder +} + +// Decrypt mocks base method. +func (m *MockSignerDecrypter) Decrypt(arg0 io.Reader, arg1 []byte, arg2 crypto.DecrypterOpts) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockSignerDecrypterMockRecorder) Decrypt(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockSignerDecrypter)(nil).Decrypt), arg0, arg1, arg2) +} + +// Delete mocks base method. +func (m *MockSignerDecrypter) Delete() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete") + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockSignerDecrypterMockRecorder) Delete() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSignerDecrypter)(nil).Delete)) +} + +// Public mocks base method. +func (m *MockSignerDecrypter) Public() crypto.PublicKey { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Public") + ret0, _ := ret[0].(crypto.PublicKey) + return ret0 +} + +// Public indicates an expected call of Public. +func (mr *MockSignerDecrypterMockRecorder) Public() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Public", reflect.TypeOf((*MockSignerDecrypter)(nil).Public)) +} + +// Sign mocks base method. +func (m *MockSignerDecrypter) Sign(arg0 io.Reader, arg1 []byte, arg2 crypto.SignerOpts) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sign", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Sign indicates an expected call of Sign. +func (mr *MockSignerDecrypterMockRecorder) Sign(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sign", reflect.TypeOf((*MockSignerDecrypter)(nil).Sign), arg0, arg1, arg2) +} diff --git a/hsm/hsm.go b/hsm/hsm.go new file mode 100644 index 00000000000..92130547270 --- /dev/null +++ b/hsm/hsm.go @@ -0,0 +1,45 @@ +//go:build hsm +// +build hsm + +package hsm + +import ( + "crypto/elliptic" + + "github.com/ThalesIgnite/crypto11" + + "github.com/ory/hydra/driver/config" + "github.com/ory/x/logrusx" +) + +type Context interface { + GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error) + GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error) + FindKeyPair(id []byte, label []byte) (crypto11.Signer, error) + FindKeyPairs(id []byte, label []byte) (signer []crypto11.Signer, err error) + GetAttribute(key interface{}, attribute crypto11.AttributeType) (a *crypto11.Attribute, err error) +} + +func NewContext(c *config.Provider, l *logrusx.Logger) Context { + config11 := &crypto11.Config{ + Path: c.HsmLibraryPath(), + Pin: c.HsmPin(), + } + + if c.HsmTokenLabel() != "" { + config11.TokenLabel = c.HsmTokenLabel() + } else { + config11.SlotNumber = c.HsmSlotNumber() + } + + ctx11, err := crypto11.Configure(config11) + if err != nil { + l.WithError(err).Fatalf("Unable to configure Hardware Security Module. Library path: %s, slot: %v, token label: %s", + c.HsmLibraryPath(), *c.HsmSlotNumber(), c.HsmTokenLabel()) + } else { + l.Info("Hardware Security Module is configured.") + } + + var hsmContext Context = ctx11 + return hsmContext +} diff --git a/hsm/hsm_mock_test.go b/hsm/hsm_mock_test.go new file mode 100644 index 00000000000..65502cf21c9 --- /dev/null +++ b/hsm/hsm_mock_test.go @@ -0,0 +1,114 @@ +//go:build hsm +// +build hsm + +// Code generated by MockGen. DO NOT EDIT. +// Source: hsm/hsm.go + +// Package hsm_test is a generated GoMock package. +package hsm_test + +import ( + elliptic "crypto/elliptic" + reflect "reflect" + + crypto11 "github.com/ThalesIgnite/crypto11" + gomock "github.com/golang/mock/gomock" +) + +// MockContext is a mock of Context interface. +type MockContext struct { + ctrl *gomock.Controller + recorder *MockContextMockRecorder +} + +// MockContextMockRecorder is the mock recorder for MockContext. +type MockContextMockRecorder struct { + mock *MockContext +} + +// NewMockContext creates a new mock instance. +func NewMockContext(ctrl *gomock.Controller) *MockContext { + mock := &MockContext{ctrl: ctrl} + mock.recorder = &MockContextMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockContext) EXPECT() *MockContextMockRecorder { + return m.recorder +} + +// FindKeyPair mocks base method. +func (m *MockContext) FindKeyPair(id, label []byte) (crypto11.Signer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindKeyPair", id, label) + ret0, _ := ret[0].(crypto11.Signer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindKeyPair indicates an expected call of FindKeyPair. +func (mr *MockContextMockRecorder) FindKeyPair(id, label interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindKeyPair", reflect.TypeOf((*MockContext)(nil).FindKeyPair), id, label) +} + +// FindKeyPairs mocks base method. +func (m *MockContext) FindKeyPairs(id, label []byte) ([]crypto11.Signer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindKeyPairs", id, label) + ret0, _ := ret[0].([]crypto11.Signer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindKeyPairs indicates an expected call of FindKeyPairs. +func (mr *MockContextMockRecorder) FindKeyPairs(id, label interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindKeyPairs", reflect.TypeOf((*MockContext)(nil).FindKeyPairs), id, label) +} + +// GenerateECDSAKeyPairWithAttributes mocks base method. +func (m *MockContext) GenerateECDSAKeyPairWithAttributes(public, private crypto11.AttributeSet, curve elliptic.Curve) (crypto11.Signer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateECDSAKeyPairWithAttributes", public, private, curve) + ret0, _ := ret[0].(crypto11.Signer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateECDSAKeyPairWithAttributes indicates an expected call of GenerateECDSAKeyPairWithAttributes. +func (mr *MockContextMockRecorder) GenerateECDSAKeyPairWithAttributes(public, private, curve interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateECDSAKeyPairWithAttributes", reflect.TypeOf((*MockContext)(nil).GenerateECDSAKeyPairWithAttributes), public, private, curve) +} + +// GenerateRSAKeyPairWithAttributes mocks base method. +func (m *MockContext) GenerateRSAKeyPairWithAttributes(public, private crypto11.AttributeSet, bits int) (crypto11.SignerDecrypter, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateRSAKeyPairWithAttributes", public, private, bits) + ret0, _ := ret[0].(crypto11.SignerDecrypter) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateRSAKeyPairWithAttributes indicates an expected call of GenerateRSAKeyPairWithAttributes. +func (mr *MockContextMockRecorder) GenerateRSAKeyPairWithAttributes(public, private, bits interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateRSAKeyPairWithAttributes", reflect.TypeOf((*MockContext)(nil).GenerateRSAKeyPairWithAttributes), public, private, bits) +} + +// GetAttribute mocks base method. +func (m *MockContext) GetAttribute(key interface{}, attribute crypto11.AttributeType) (*crypto11.Attribute, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAttribute", key, attribute) + ret0, _ := ret[0].(*crypto11.Attribute) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAttribute indicates an expected call of GetAttribute. +func (mr *MockContextMockRecorder) GetAttribute(key, attribute interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAttribute", reflect.TypeOf((*MockContext)(nil).GetAttribute), key, attribute) +} diff --git a/hsm/manager_hsm.go b/hsm/manager_hsm.go new file mode 100644 index 00000000000..928a61ce960 --- /dev/null +++ b/hsm/manager_hsm.go @@ -0,0 +1,311 @@ +//go:build hsm +// +build hsm + +package hsm + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "net/http" + "sync" + + "github.com/pkg/errors" + + "github.com/pborman/uuid" + + "github.com/ory/fosite" + "github.com/ory/hydra/jwk" + + "github.com/miekg/pkcs11" + + "github.com/ory/hydra/x" + + "github.com/ThalesIgnite/crypto11" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/cryptosigner" +) + +type KeyManager struct { + jwk.Manager + sync.RWMutex + Context +} + +var ErrPreGeneratedKeys = &fosite.RFC6749Error{ + CodeField: http.StatusBadRequest, + ErrorField: http.StatusText(http.StatusBadRequest), + DescriptionField: "Cannot add/update pre generated keys on Hardware Security Module", +} + +func NewKeyManager(hsm Context) *KeyManager { + return &KeyManager{ + Context: hsm, + } +} + +func (m *KeyManager) GenerateAndPersistKeySet(_ context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + m.Lock() + defer m.Unlock() + + err := m.deleteExistingKeySet(set) + if err != nil { + return nil, err + } + + if len(kid) == 0 { + kid = uuid.New() + } + + privateAttrSet, publicAttrSet, err := getKeyPairAttributes(kid, set, use) + if err != nil { + return nil, err + } + + switch { + case alg == "RS256": + key, err := m.GenerateRSAKeyPairWithAttributes(publicAttrSet, privateAttrSet, 4096) + if err != nil { + return nil, err + } + return createKeySet(key, kid, alg, use) + case alg == "ES256": + key, err := m.GenerateECDSAKeyPairWithAttributes(publicAttrSet, privateAttrSet, elliptic.P256()) + if err != nil { + return nil, err + } + return createKeySet(key, kid, alg, use) + case alg == "ES512": + key, err := m.GenerateECDSAKeyPairWithAttributes(publicAttrSet, privateAttrSet, elliptic.P521()) + if err != nil { + return nil, err + } + return createKeySet(key, kid, alg, use) + + // NOTE: + // - HS256, HS512 not supported. Makes sense only if shared HSM is used between Hydra and authenticating client. + // - EdDSA not supported. As of now PKCS#11 v2.4 doesn't support EdDSA keys using curve Ed25519. However, + // PKCS#11 3.0 (https://docs.oasis-open.org/pkcs11/pkcs11-curr/v3.0/pkcs11-curr-v3.0.html) + // contains support for EdDSA. + + default: + return nil, errors.WithStack(jwk.ErrUnsupportedKeyAlgorithm) + } +} + +func (m *KeyManager) GetKey(_ context.Context, set, kid string) (*jose.JSONWebKeySet, error) { + m.RLock() + defer m.RUnlock() + + keyPair, err := m.FindKeyPair([]byte(kid), []byte(set)) + if err != nil { + return nil, err + } + + if keyPair == nil { + return nil, errors.WithStack(x.ErrNotFound) + } + + id, alg, use, err := getKeySetAttributes(m, keyPair, []byte(kid)) + if err != nil { + return nil, err + } + + return createKeySet(keyPair, id, alg, use) +} + +func (m *KeyManager) GetKeySet(_ context.Context, set string) (*jose.JSONWebKeySet, error) { + m.RLock() + defer m.RUnlock() + + keyPairs, err := m.FindKeyPairs(nil, []byte(set)) + if err != nil { + return nil, err + } + + if keyPairs == nil { + return nil, errors.WithStack(x.ErrNotFound) + } + + var keys []jose.JSONWebKey + for _, keyPair := range keyPairs { + kid, alg, use, err := getKeySetAttributes(m, keyPair, nil) + if err != nil { + return nil, err + } + keys = append(keys, createKeys(keyPair, kid, alg, use)...) + } + + return &jose.JSONWebKeySet{ + Keys: keys, + }, nil +} + +func (m *KeyManager) DeleteKey(_ context.Context, set, kid string) error { + m.Lock() + defer m.Unlock() + + keyPair, err := m.FindKeyPair([]byte(kid), []byte(set)) + if err != nil { + return err + } + + if keyPair != nil { + err = keyPair.Delete() + if err != nil { + return err + } + } else { + return errors.WithStack(x.ErrNotFound) + } + return nil +} + +func (m *KeyManager) DeleteKeySet(_ context.Context, set string) error { + m.Lock() + defer m.Unlock() + + keyPairs, err := m.FindKeyPairs(nil, []byte(set)) + if err != nil { + return err + } + + if keyPairs == nil { + return errors.WithStack(x.ErrNotFound) + } + + for _, keyPair := range keyPairs { + err = keyPair.Delete() + if err != nil { + return err + } + } + return nil +} + +func (m *KeyManager) AddKey(_ context.Context, _ string, _ *jose.JSONWebKey) error { + return errors.WithStack(ErrPreGeneratedKeys) +} + +func (m *KeyManager) AddKeySet(_ context.Context, _ string, _ *jose.JSONWebKeySet) error { + return errors.WithStack(ErrPreGeneratedKeys) +} + +func (m *KeyManager) UpdateKey(_ context.Context, _ string, _ *jose.JSONWebKey) error { + return errors.WithStack(ErrPreGeneratedKeys) +} + +func (m *KeyManager) UpdateKeySet(_ context.Context, _ string, _ *jose.JSONWebKeySet) error { + return errors.WithStack(ErrPreGeneratedKeys) +} + +func getKeySetAttributes(m *KeyManager, key crypto11.Signer, kid []byte) (string, string, string, error) { + if kid == nil { + ckaId, err := m.GetAttribute(key, crypto11.CkaId) + if err != nil { + return "", "", "", err + } + kid = ckaId.Value + } + + var alg string + switch k := key.Public().(type) { + case *rsa.PublicKey: + alg = "RS256" + // TODO Should we validate minimal key length by checking CKA_MODULUS_BITS? + // TODO see https://github.com/ory/hydra/issues/2905 + case *ecdsa.PublicKey: + if k.Curve == elliptic.P521() { + alg = "ES512" + } else if k.Curve == elliptic.P256() { + alg = "ES256" + } else { + return "", "", "", errors.WithStack(jwk.ErrUnsupportedEllipticCurve) + } + default: + return "", "", "", errors.WithStack(jwk.ErrUnsupportedKeyAlgorithm) + } + + use := "sig" + ckaDecrypt, _ := m.GetAttribute(key, crypto11.CkaDecrypt) + if ckaDecrypt != nil && len(ckaDecrypt.Value) != 0 && ckaDecrypt.Value[0] == 0x1 { + use = "enc" + } + return string(kid), alg, use, nil +} + +func getKeyPairAttributes(kid string, set string, use string) (crypto11.AttributeSet, crypto11.AttributeSet, error) { + + privateAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(set)) + if err != nil { + return nil, nil, err + } + + publicAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(set)) + if err != nil { + return nil, nil, err + } + + if len(use) == 0 || use == "sig" { + publicAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, false), + }) + privateAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, false), + }) + } else { + publicAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, false), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), + }) + privateAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_SIGN, false), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), + }) + } + + return privateAttrSet, publicAttrSet, nil +} + +func (m *KeyManager) deleteExistingKeySet(set string) error { + existingKeyPairs, err := m.FindKeyPairs(nil, []byte(set)) + if err != nil { + return err + } + if len(existingKeyPairs) != 0 { + for _, keyPair := range existingKeyPairs { + _ = keyPair.Delete() + } + } + return nil +} + +func createKeySet(key crypto11.Signer, kid, alg, use string) (*jose.JSONWebKeySet, error) { + return &jose.JSONWebKeySet{ + Keys: createKeys(key, kid, alg, use), + }, nil +} + +func createKeys(key crypto11.Signer, kid, alg, use string) []jose.JSONWebKey { + return []jose.JSONWebKey{{ + Algorithm: alg, + Use: use, + Key: cryptosigner.Opaque(key), + KeyID: kid, + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }, { + Algorithm: alg, + Use: use, + Key: key.Public(), + KeyID: kid, + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }} +} diff --git a/hsm/manager_hsm_test.go b/hsm/manager_hsm_test.go new file mode 100644 index 00000000000..a60b80d54d6 --- /dev/null +++ b/hsm/manager_hsm_test.go @@ -0,0 +1,807 @@ +//go:build hsm +// +build hsm + +package hsm_test + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "reflect" + "testing" + + "github.com/ory/hydra/jwk" + + "github.com/ory/hydra/driver" + "github.com/ory/hydra/driver/config" + "github.com/ory/hydra/persistence/sql" + "github.com/ory/x/configx" + "github.com/ory/x/logrusx" + + "github.com/ThalesIgnite/crypto11" + "github.com/golang/mock/gomock" + "github.com/miekg/pkcs11" + "github.com/pborman/uuid" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/cryptosigner" + + "github.com/ory/hydra/hsm" + "github.com/ory/hydra/x" +) + +func TestDefaultKeyManager_HsmEnabled(t *testing.T) { + ctrl := gomock.NewController(t) + mockHsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + l := logrusx.New("", "") + c := config.MustNew(l, configx.SkipValidation()) + c.MustSet(config.KeyDSN, "memory") + c.MustSet(config.HsmEnabled, "true") + reg := driver.NewRegistrySQL() + reg.WithLogger(l) + reg.WithConfig(c) + reg.WithHsmContext(mockHsmContext) + err := reg.Init(context.Background()) + assert.NoError(t, err) + assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager()) + assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) +} + +func TestKeyManager_GenerateAndPersistKeySet(t *testing.T) { + ctrl := gomock.NewController(t) + hsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + + rsaKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) + + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + + rsaKeyPair := NewMockSignerDecrypter(ctrl) + rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes() + + ecdsaKeyPair := NewMockSignerDecrypter(ctrl) + ecdsaKeyPair.EXPECT().Public().Return(&ecdsaKey.PublicKey).AnyTimes() + + var kid = uuid.New() + + type args struct { + ctx context.Context + set string + kid string + alg string + use string + } + tests := []struct { + name string + setup func(t *testing.T) + args args + want *jose.JSONWebKeySet + wantErrMsg string + wantErr error + }{ + { + name: "Generate RS256", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "RS256", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(rsaKeyPair, nil) + }, + want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"), + }, + { + name: "Generate RS256 with GenerateRSAKeyPairWithAttributes Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "RS256", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateRSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(4096)).Return(nil, errors.New("GenerateRSAKeyPairWithAttributesError")) + }, + wantErrMsg: "GenerateRSAKeyPairWithAttributesError", + }, + { + name: "Generate ES256", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "ES256", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P256())).Return(ecdsaKeyPair, nil) + }, + want: expectedKeySet(ecdsaKeyPair, kid, "ES256", "sig"), + }, + { + name: "Generate ES256 with GenerateECDSAKeyPairWithAttributes Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "ES256", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P256())).Return(nil, errors.New("GenerateECDSAKeyPairWithAttributesError")) + }, + wantErrMsg: "GenerateECDSAKeyPairWithAttributesError", + }, + { + name: "Generate ES512", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "ES512", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P521())).Return(ecdsaKeyPair, nil) + }, + want: expectedKeySet(ecdsaKeyPair, kid, "ES512", "sig"), + }, + { + name: "Generate ES512 GenerateECDSAKeyPairWithAttributes Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "ES512", + use: "sig", + }, + setup: func(t *testing.T) { + privateAttrSet, publicAttrSet := expectedKeyAttributes(t, kid) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + hsmContext.EXPECT().GenerateECDSAKeyPairWithAttributes(gomock.Eq(publicAttrSet), gomock.Eq(privateAttrSet), gomock.Eq(elliptic.P521())).Return(nil, errors.New("GenerateECDSAKeyPairWithAttributesError")) + }, + wantErrMsg: "GenerateECDSAKeyPairWithAttributesError", + }, + { + name: "Generate unsupported", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "ES384", + use: "sig", + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + }, + wantErr: errors.WithStack(jwk.ErrUnsupportedKeyAlgorithm), + }, + { + name: "Generate with FindKeyPair Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + alg: "RS256", + use: "sig", + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError")) + }, + wantErrMsg: "FindKeyPairError", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + m := &hsm.KeyManager{ + Context: hsmContext, + } + got, err := m.GenerateAndPersistKeySet(tt.args.ctx, tt.args.set, tt.args.kid, tt.args.alg, tt.args.use) + if tt.wantErr != nil { + require.Nil(t, got) + require.IsType(t, tt.wantErr, err) + } else if len(tt.wantErrMsg) != 0 { + require.Nil(t, got) + require.EqualError(t, err, tt.wantErrMsg) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GenerateAndPersistKeySet() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyManager_GetKey(t *testing.T) { + ctrl := gomock.NewController(t) + hsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + + rsaKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) + rsaKeyPair := NewMockSignerDecrypter(ctrl) + rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes() + + ecdsaP256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + ecdsaP256KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP256KeyPair.EXPECT().Public().Return(&ecdsaP256Key.PublicKey).AnyTimes() + + ecdsaP521Key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + ecdsaP521KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP521KeyPair.EXPECT().Public().Return(&ecdsaP521Key.PublicKey).AnyTimes() + + ecdsaP224Key, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + require.NoError(t, err) + ecdsaP224KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP224KeyPair.EXPECT().Public().Return(&ecdsaP224Key.PublicKey).AnyTimes() + + var kid = uuid.New() + + type args struct { + ctx context.Context + set string + kid string + } + tests := []struct { + name string + setup func(t *testing.T) + args args + want *jose.JSONWebKeySet + wantErrMsg string + wantErr error + }{ + { + name: "Get RS256 sig", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + }, + want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"), + }, + { + name: "Get RS256 enc", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil) + }, + want: expectedKeySet(rsaKeyPair, kid, "RS256", "enc"), + }, + { + name: "Key usage attribute error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, errors.New("GetAttributeError")) + }, + want: expectedKeySet(rsaKeyPair, kid, "RS256", "sig"), + }, + { + name: "Get ES256 sig", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP256KeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + }, + want: expectedKeySet(ecdsaP256KeyPair, kid, "ES256", "sig"), + }, + { + name: "Get ES256 enc", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP256KeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil) + }, + want: expectedKeySet(ecdsaP256KeyPair, kid, "ES256", "enc"), + }, + { + name: "Get ES512 sig", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP521KeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + }, + want: expectedKeySet(ecdsaP521KeyPair, kid, "ES512", "sig"), + }, + { + name: "Get ES512 enc", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP521KeyPair, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), nil) + }, + want: expectedKeySet(ecdsaP521KeyPair, kid, "ES512", "enc"), + }, + { + name: "Key not found", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + }, + wantErrMsg: "Not Found", + }, + { + name: "FindKeyPair Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError")) + }, + wantErrMsg: "FindKeyPairError", + }, + { + name: "Unsupported elliptic curve", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(ecdsaP224KeyPair, nil) + }, + wantErr: errors.WithStack(jwk.ErrUnsupportedEllipticCurve), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + m := &hsm.KeyManager{ + Context: hsmContext, + } + got, err := m.GetKey(tt.args.ctx, tt.args.set, tt.args.kid) + if tt.wantErr != nil { + require.Nil(t, got) + require.IsType(t, tt.wantErr, err) + } else if len(tt.wantErrMsg) != 0 { + require.Nil(t, got) + require.EqualError(t, err, tt.wantErrMsg) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetKey() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyManager_GetKeySet(t *testing.T) { + ctrl := gomock.NewController(t) + hsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + + rsaKey, err := rsa.GenerateKey(rand.Reader, 512) + require.NoError(t, err) + rsaKid := uuid.New() + rsaKeyPair := NewMockSignerDecrypter(ctrl) + rsaKeyPair.EXPECT().Public().Return(&rsaKey.PublicKey).AnyTimes() + + ecdsaP256Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + ecdsaP256Kid := uuid.New() + ecdsaP256KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP256KeyPair.EXPECT().Public().Return(&ecdsaP256Key.PublicKey).AnyTimes() + + ecdsaP521Key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + require.NoError(t, err) + ecdsaP521Kid := uuid.New() + ecdsaP521KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP521KeyPair.EXPECT().Public().Return(&ecdsaP521Key.PublicKey).AnyTimes() + + ecdsaP224Key, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + require.NoError(t, err) + ecdsaP224Kid := uuid.New() + ecdsaP224KeyPair := NewMockSignerDecrypter(ctrl) + ecdsaP224KeyPair.EXPECT().Public().Return(&ecdsaP224Key.PublicKey).AnyTimes() + + allKeys := []crypto11.Signer{rsaKeyPair, ecdsaP256KeyPair, ecdsaP521KeyPair} + + var keys []jose.JSONWebKey + keys = append(keys, createJSONWebKeys(rsaKeyPair, rsaKid, "RS256", "sig")...) + keys = append(keys, createJSONWebKeys(ecdsaP256KeyPair, ecdsaP256Kid, "ES256", "sig")...) + keys = append(keys, createJSONWebKeys(ecdsaP521KeyPair, ecdsaP521Kid, "ES512", "sig")...) + + type args struct { + ctx context.Context + set string + } + tests := []struct { + name string + setup func(t *testing.T) + args args + want *jose.JSONWebKeySet + wantErrMsg string + wantErr error + }{ + { + name: "With multiple keys per set", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(rsaKid)), nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP256Kid)), nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP256KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP521Kid)), nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP521KeyPair), gomock.Eq(crypto11.CkaDecrypt)).Return(nil, nil) + }, + want: &jose.JSONWebKeySet{Keys: keys}, + }, + { + name: "GetCkaIdAttributeError Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(rsaKeyPair), gomock.Eq(crypto11.CkaId)).Return(nil, errors.New("GetCkaIdAttributeError")) + }, + wantErrMsg: "GetCkaIdAttributeError", + }, + { + name: "Key set not found", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + }, + wantErrMsg: "Not Found", + }, + { + name: "FindKeyPairs Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairsError")) + }, + wantErrMsg: "FindKeyPairsError", + }, + { + name: "Unsupported elliptic curve", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return([]crypto11.Signer{ecdsaP224KeyPair}, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(ecdsaP224KeyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(ecdsaP224Kid)), nil) + }, + wantErr: errors.WithStack(jwk.ErrUnsupportedEllipticCurve), + }, + { + name: "Invalid key type Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + keyPair := NewMockSignerDecrypter(ctrl) + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return([]crypto11.Signer{keyPair}, nil) + hsmContext.EXPECT().GetAttribute(gomock.Eq(keyPair), gomock.Eq(crypto11.CkaId)).Return(pkcs11.NewAttribute(pkcs11.CKA_ID, []byte(rsaKid)), nil) + keyPair.EXPECT().Public().Return(nil).Times(1) + }, + wantErr: errors.WithStack(jwk.ErrUnsupportedKeyAlgorithm), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + m := &hsm.KeyManager{ + Context: hsmContext, + } + got, err := m.GetKeySet(tt.args.ctx, tt.args.set) + if tt.wantErr != nil { + require.Nil(t, got) + require.IsType(t, tt.wantErr, err) + } else if len(tt.wantErrMsg) != 0 { + require.Nil(t, got) + require.EqualError(t, err, tt.wantErrMsg) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetKey() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKeyManager_DeleteKey(t *testing.T) { + ctrl := gomock.NewController(t) + hsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + + rsaKeyPair := NewMockSignerDecrypter(ctrl) + + kid := uuid.New() + + type args struct { + ctx context.Context + set string + kid string + } + tests := []struct { + name string + setup func(t *testing.T) + args args + wantErrMsg string + }{ + { + name: "Existing key", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil) + rsaKeyPair.EXPECT().Delete().Return(nil) + }, + }, + { + name: "Key not found", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + }, + wantErrMsg: "Not Found", + }, + { + name: "FindKeyPair Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairError")) + }, + wantErrMsg: "FindKeyPairError", + }, + { + name: "Delete Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + kid: kid, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPair(gomock.Eq([]byte(kid)), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(rsaKeyPair, nil) + rsaKeyPair.EXPECT().Delete().Return(errors.New("DeleteError")) + }, + wantErrMsg: "DeleteError", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + m := &hsm.KeyManager{ + Context: hsmContext, + } + if err := m.DeleteKey(tt.args.ctx, tt.args.set, tt.args.kid); len(tt.wantErrMsg) != 0 { + require.EqualError(t, err, tt.wantErrMsg) + } + }) + } +} + +func TestKeyManager_DeleteKeySet(t *testing.T) { + ctrl := gomock.NewController(t) + hsmContext := NewMockContext(ctrl) + defer ctrl.Finish() + + rsaKeyPair1 := NewMockSignerDecrypter(ctrl) + rsaKeyPair2 := NewMockSignerDecrypter(ctrl) + allKeys := []crypto11.Signer{rsaKeyPair1, rsaKeyPair2} + + type args struct { + ctx context.Context + set string + } + tests := []struct { + name string + setup func(t *testing.T) + args args + wantErrMsg string + }{ + { + name: "Existing key", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil) + rsaKeyPair1.EXPECT().Delete().Return(nil) + rsaKeyPair2.EXPECT().Delete().Return(nil) + }, + }, + { + name: "Key not found", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, nil) + }, + wantErrMsg: "Not Found", + }, + { + name: "FindKeyPairs Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(nil, errors.New("FindKeyPairsError")) + }, + wantErrMsg: "FindKeyPairsError", + }, + { + name: "Delete Error", + args: args{ + ctx: context.TODO(), + set: x.OpenIDConnectKeyName, + }, + setup: func(t *testing.T) { + hsmContext.EXPECT().FindKeyPairs(gomock.Nil(), gomock.Eq([]byte(x.OpenIDConnectKeyName))).Return(allKeys, nil) + rsaKeyPair1.EXPECT().Delete().Return(errors.New("DeleteError")) + }, + wantErrMsg: "DeleteError", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + m := &hsm.KeyManager{ + Context: hsmContext, + } + if err := m.DeleteKeySet(tt.args.ctx, tt.args.set); len(tt.wantErrMsg) != 0 { + require.EqualError(t, err, tt.wantErrMsg) + } + }) + } +} + +func TestKeyManager_AddKey(t *testing.T) { + m := &hsm.KeyManager{ + Context: nil, + } + err := m.AddKey(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKey{}) + assert.ErrorIs(t, err, hsm.ErrPreGeneratedKeys) +} + +func TestKeyManager_AddKeySet(t *testing.T) { + m := &hsm.KeyManager{ + Context: nil, + } + err := m.AddKeySet(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKeySet{}) + assert.ErrorIs(t, err, hsm.ErrPreGeneratedKeys) +} + +func TestKeyManager_UpdateKey(t *testing.T) { + m := &hsm.KeyManager{ + Context: nil, + } + err := m.UpdateKey(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKey{}) + assert.ErrorIs(t, err, hsm.ErrPreGeneratedKeys) +} + +func TestKeyManager_UpdateKeySet(t *testing.T) { + m := &hsm.KeyManager{ + Context: nil, + } + err := m.UpdateKeySet(context.TODO(), x.OpenIDConnectKeyName, &jose.JSONWebKeySet{}) + assert.ErrorIs(t, err, hsm.ErrPreGeneratedKeys) +} + +func expectedKeyAttributes(t *testing.T, kid string) (crypto11.AttributeSet, crypto11.AttributeSet) { + privateAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(x.OpenIDConnectKeyName)) + require.NoError(t, err) + publicAttrSet, err := crypto11.NewAttributeSetWithIDAndLabel([]byte(kid), []byte(x.OpenIDConnectKeyName)) + require.NoError(t, err) + publicAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, true), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, false), + }) + privateAttrSet.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_SIGN, true), + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, false), + }) + return privateAttrSet, publicAttrSet +} + +func expectedKeySet(keyPair *MockSignerDecrypter, kid, alg, use string) *jose.JSONWebKeySet { + return &jose.JSONWebKeySet{Keys: createJSONWebKeys(keyPair, kid, alg, use)} +} + +func createJSONWebKeys(keyPair *MockSignerDecrypter, kid string, alg string, use string) []jose.JSONWebKey { + return []jose.JSONWebKey{{ + Algorithm: alg, + Use: use, + Key: cryptosigner.Opaque(keyPair), + KeyID: kid, + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }, { + Algorithm: alg, + Use: use, + Key: keyPair.Public(), + KeyID: kid, + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }} +} diff --git a/hsm/manager_nohsm.go b/hsm/manager_nohsm.go new file mode 100644 index 00000000000..2e20d9a7891 --- /dev/null +++ b/hsm/manager_nohsm.go @@ -0,0 +1,74 @@ +//go:build !hsm +// +build !hsm + +package hsm + +import ( + "context" + "sync" + + "github.com/ory/hydra/driver/config" + "github.com/ory/x/logrusx" + + "github.com/pkg/errors" + + "github.com/ory/hydra/jwk" + + "gopkg.in/square/go-jose.v2" +) + +type Context interface { +} + +type KeyManager struct { + jwk.Manager + sync.RWMutex + Context +} + +var ErrOpSysNotSupported = errors.New("Hardware Security Module is not supported on this platform.") + +func NewContext(c *config.Provider, l *logrusx.Logger) Context { + l.Fatalf("Hardware Security Module is not supported on this platform.") + return nil +} + +func NewKeyManager(hsm Context) *KeyManager { + return nil +} + +func (m *KeyManager) GenerateAndPersistKeySet(_ context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + return nil, errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) GetKey(_ context.Context, set, kid string) (*jose.JSONWebKeySet, error) { + return nil, errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) GetKeySet(_ context.Context, set string) (*jose.JSONWebKeySet, error) { + return nil, errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) DeleteKey(_ context.Context, set, kid string) error { + return errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) DeleteKeySet(_ context.Context, set string) error { + return errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) AddKey(_ context.Context, _ string, _ *jose.JSONWebKey) error { + return errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) AddKeySet(_ context.Context, _ string, _ *jose.JSONWebKeySet) error { + return errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) UpdateKey(_ context.Context, _ string, _ *jose.JSONWebKey) error { + return errors.WithStack(ErrOpSysNotSupported) +} + +func (m *KeyManager) UpdateKeySet(_ context.Context, _ string, _ *jose.JSONWebKeySet) error { + return errors.WithStack(ErrOpSysNotSupported) +} diff --git a/internal/.hydra.yaml b/internal/.hydra.yaml index a94c722a7b0..63ed16b035e 100644 --- a/internal/.hydra.yaml +++ b/internal/.hydra.yaml @@ -64,6 +64,9 @@ serve: dsn: memory +hsm: + enabled: false + webfinger: jwks: broadcast_keys: diff --git a/internal/config/config.yaml b/internal/config/config.yaml index dc6aeb2639b..97285fa24fd 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -267,6 +267,15 @@ dsn: memory # dsn: postgres://user:password@host:123/database # dsn: mysql://user:password@tcp(host:123)/database +# hsm configures Hardware Security Module for hydra.openid.id-token, hydra.jwt.access-token keys +# Either slot or token_label must be set. If token_label is set, then first slot in index with this label is used. +hsm: + enabled: false + library: /path/to/hsm-vendor/library.so + pin: token-pin-code + slot: 0 + token_label: hydra + # webfinger configures ./well-known/ settings webfinger: # jwks configures the /.well-known/jwks.json endpoint. diff --git a/internal/driver.go b/internal/driver.go index aa06527886d..ffb52c71471 100644 --- a/internal/driver.go +++ b/internal/driver.go @@ -58,6 +58,18 @@ func newRegistryDefault(t *testing.T, url string, c *config.Provider) driver.Reg c.MustSet(config.KeyDSN, url) r, err := driver.NewRegistryFromDSN(context.Background(), c, logrusx.New("test_hydra", "master")) + + kg := map[string]jwk.KeyGenerator{ + "RS256": new(veryInsecureRS256Generator), + "ES256": &jwk.ECDSA256Generator{}, + "ES512": &jwk.ECDSA512Generator{}, + "EdDSA": &jwk.EdDSAGenerator{}, + "HS256": &jwk.HS256Generator{}, + "HS512": &jwk.HS512Generator{}, + } + + r = r.WithKeyGenerators(kg) + require.NoError(t, err) require.NoError(t, r.Init(context.Background())) @@ -144,7 +156,7 @@ func ConnectDatabases(t *testing.T) (pg, mysql, crdb driver.Registry, clean func } func MustEnsureRegistryKeys(r driver.Registry, key string) { - if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), r, new(veryInsecureRS256Generator), key); err != nil { + if err := jwk.EnsureAsymmetricKeypairExists(context.Background(), r, "RS256", key); err != nil { panic(err) } } diff --git a/jwk/handler.go b/jwk/handler.go index ae918556764..cac7143e9b1 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -34,7 +34,6 @@ import ( "github.com/ory/hydra/x" "github.com/julienschmidt/httprouter" - "github.com/pkg/errors" jose "gopkg.in/square/go-jose.v2" ) @@ -130,6 +129,7 @@ func (h *Handler) GetKey(w http.ResponseWriter, r *http.Request, ps httprouter.P h.r.Writer().WriteError(w, r, err) return } + keys = ExcludeOpaquePrivateKeys(keys) h.r.Writer().Write(w, r, keys) } @@ -163,6 +163,7 @@ func (h *Handler) GetKeySet(w http.ResponseWriter, r *http.Request, ps httproute h.r.Writer().WriteError(w, r, err) return } + keys = ExcludeOpaquePrivateKeys(keys) h.r.Writer().Write(w, r, keys) } @@ -196,24 +197,12 @@ func (h *Handler) Create(w http.ResponseWriter, r *http.Request, ps httprouter.P h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) } - generator, found := h.r.KeyGenerators()[keyRequest.Algorithm] - if !found { - h.r.Writer().WriteErrorCode(w, r, http.StatusBadRequest, errors.Errorf("Generator %s unknown", keyRequest.Algorithm)) - return - } - - keys, err := generator.Generate(keyRequest.KeyID, keyRequest.Use) - if err != nil { + if keys, err := h.r.KeyManager().GenerateAndPersistKeySet(r.Context(), set, keyRequest.KeyID, keyRequest.Algorithm, keyRequest.Use); err == nil { + keys = ExcludeOpaquePrivateKeys(keys) + h.r.Writer().WriteCreated(w, r, fmt.Sprintf("%s://%s/keys/%s", r.URL.Scheme, r.URL.Host, set), keys) + } else { h.r.Writer().WriteError(w, r, err) - return - } - - if err := h.r.KeyManager().AddKeySet(r.Context(), set, keys); err != nil { - h.r.Writer().WriteError(w, r, err) - return } - - h.r.Writer().WriteCreated(w, r, fmt.Sprintf("%s://%s/keys/%s", r.URL.Scheme, r.URL.Host, set), keys) } // swagger:route PUT /keys/{set} admin updateJsonWebKeySet @@ -246,12 +235,7 @@ func (h *Handler) UpdateKeySet(w http.ResponseWriter, r *http.Request, ps httpro return } - if err := h.r.KeyManager().DeleteKeySet(r.Context(), set); err != nil { - h.r.Writer().WriteError(w, r, err) - return - } - - if err := h.r.KeyManager().AddKeySet(r.Context(), set, &keySet); err != nil { + if err := h.r.KeyManager().UpdateKeySet(r.Context(), set, &keySet); err != nil { h.r.Writer().WriteError(w, r, err) return } @@ -289,12 +273,7 @@ func (h *Handler) UpdateKey(w http.ResponseWriter, r *http.Request, ps httproute return } - if err := h.r.KeyManager().DeleteKey(r.Context(), set, key.KeyID); err != nil { - h.r.Writer().WriteError(w, r, err) - return - } - - if err := h.r.KeyManager().AddKey(r.Context(), set, &key); err != nil { + if err := h.r.KeyManager().UpdateKey(r.Context(), set, &key); err != nil { h.r.Writer().WriteError(w, r, err) return } diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 7492fac852e..87451197375 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -53,6 +53,9 @@ func TestHandlerWellKnown(t *testing.T) { JWKPath := "/.well-known/jwks.json" t.Run("Test_Handler_WellKnown/Run_public_key_With_public_prefix", func(t *testing.T) { + if conf.HsmEnabled() { + t.Skip("Skipping test. Not applicable when Hardware Security Module is enabled. Public/private keys on HSM are generated with equal key id's and are not using prefixes") + } IDKS, _ := testGenerator.Generate("test-id-1", "sig") require.NoError(t, reg.KeyManager().AddKeySet(context.TODO(), x.OpenIDConnectKeyName, IDKS)) res, err := http.Get(testServer.URL + JWKPath) @@ -65,21 +68,30 @@ func TestHandlerWellKnown(t *testing.T) { require.Len(t, known.Keys, 1) - resp := known.Key("public:test-id-1") - require.NotNil(t, resp, "Could not find key public") + knownKey := known.Key("public:test-id-1")[0] + require.NotNil(t, knownKey, "Could not find key public") - assert.EqualValues(t, canonicalizeThumbprints(resp), canonicalizeThumbprints(IDKS.Key("public:test-id-1"))) + expectedKey, err := jwk.FindPublicKey(IDKS) + require.NoError(t, err) + assert.EqualValues(t, canonicalizeThumbprints(*expectedKey), canonicalizeThumbprints(knownKey)) require.NoError(t, reg.KeyManager().DeleteKeySet(context.TODO(), x.OpenIDConnectKeyName)) }) t.Run("Test_Handler_WellKnown/Run_public_key_Without_public_prefix", func(t *testing.T) { - IDKS, _ := testGenerator.Generate("test-id-2", "sig") - if strings.ContainsAny(IDKS.Keys[1].KeyID, "public") { - IDKS.Keys[1].KeyID = "test-id-2" + var IDKS *jose.JSONWebKeySet + + if conf.HsmEnabled() { + IDKS, _ = reg.KeyManager().GenerateAndPersistKeySet(context.TODO(), x.OpenIDConnectKeyName, "test-id-2", "RS256", "sig") } else { - IDKS.Keys[0].KeyID = "test-id-2" + IDKS, _ = testGenerator.Generate("test-id-2", "sig") + if strings.ContainsAny(IDKS.Keys[1].KeyID, "public") { + IDKS.Keys[1].KeyID = "test-id-2" + } else { + IDKS.Keys[0].KeyID = "test-id-2" + } + require.NoError(t, reg.KeyManager().AddKeySet(context.TODO(), x.OpenIDConnectKeyName, IDKS)) } - require.NoError(t, reg.KeyManager().AddKeySet(context.TODO(), x.OpenIDConnectKeyName, IDKS)) + res, err := http.Get(testServer.URL + JWKPath) require.NoError(t, err, "problem in http request") defer res.Body.Close() @@ -89,22 +101,21 @@ func TestHandlerWellKnown(t *testing.T) { require.NoError(t, err, "problem in decoding response") require.Len(t, known.Keys, 1) - resp := known.Key("test-id-2") - require.NotNil(t, resp, "Could not find key public") + knownKey := known.Key("test-id-2")[0] + require.NotNil(t, knownKey, "Could not find key public") - assert.EqualValues(t, canonicalizeThumbprints(resp), canonicalizeThumbprints(IDKS.Key("test-id-2"))) + expectedKey, err := jwk.FindPublicKey(IDKS) + require.NoError(t, err) + assert.EqualValues(t, canonicalizeThumbprints(*expectedKey), canonicalizeThumbprints(knownKey)) }) } -func canonicalizeThumbprints(js []jose.JSONWebKey) []jose.JSONWebKey { - for k, v := range js { - if len(v.CertificateThumbprintSHA1) == 0 { - v.CertificateThumbprintSHA1 = nil - } - if len(v.CertificateThumbprintSHA256) == 0 { - v.CertificateThumbprintSHA256 = nil - } - js[k] = v +func canonicalizeThumbprints(js jose.JSONWebKey) jose.JSONWebKey { + if len(js.CertificateThumbprintSHA1) == 0 { + js.CertificateThumbprintSHA1 = nil + } + if len(js.CertificateThumbprintSHA256) == 0 { + js.CertificateThumbprintSHA256 = nil } return js } diff --git a/jwk/helper.go b/jwk/helper.go index 05556b77055..73e29c3fdb2 100644 --- a/jwk/helper.go +++ b/jwk/helper.go @@ -38,71 +38,47 @@ import ( jose "gopkg.in/square/go-jose.v2" ) -func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, g KeyGenerator, set string) error { - _, _, err := AsymmetricKeypair(ctx, r, g, set) +func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error { + _, _, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg) return err } -func AsymmetricKeypair(ctx context.Context, r InternalRegistry, g KeyGenerator, set string) (public, private *jose.JSONWebKey, err error) { - priv, err := GetOrCreateKey(ctx, r, g, set, "private") - if err != nil { - return nil, nil, err - } - - pub, err := GetOrCreateKey(ctx, r, g, set, "public") - if err != nil { - return nil, nil, err - } - - return pub, priv, nil -} - -func GetOrCreateKey(ctx context.Context, r InternalRegistry, g KeyGenerator, set, prefix string) (*jose.JSONWebKey, error) { - keys, err := r.KeyManager().GetKeySet(ctx, set) +func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (public, private *jose.JSONWebKey, err error) { + keys, err := m.GetKeySet(ctx, set) if errors.Is(err, x.ErrNotFound) || keys != nil && len(keys.Keys) == 0 { r.Logger().Warnf("JSON Web Key Set \"%s\" does not exist yet, generating new key pair...", set) - keys, err = createKey(ctx, r, g, set) + keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") if err != nil { - return nil, err + return nil, nil, err } } else if err != nil { - return nil, err + return nil, nil, err } - key, err := FindKeyByPrefix(keys, prefix) - if err != nil { - r.Logger().Warnf("JSON Web Key with prefix %s not found in JSON Web Key Set %s, generating new key pair...", prefix, set) - - keys, err = createKey(ctx, r, g, set) + pubKey, pubKeyErr := FindPublicKey(keys) + privKey, privKeyErr := FindPrivateKey(keys) + if pubKeyErr == nil && privKeyErr == nil { + return pubKey, privKey, nil + } else { + if pubKeyErr != nil { + r.Logger().Warnf("Public JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + } else { + r.Logger().Warnf("Private JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set) + } + keys, err = m.GenerateAndPersistKeySet(ctx, set, kid, alg, "sig") if err != nil { - return nil, err + return nil, nil, err } - - key, err = FindKeyByPrefix(keys, prefix) + pubKey, err := FindPublicKey(keys) if err != nil { - return nil, err + return nil, nil, err } + privKey, err := FindPrivateKey(keys) + if err != nil { + return nil, nil, err + } + return pubKey, privKey, nil } - - return key, nil -} - -func createKey(ctx context.Context, r InternalRegistry, g KeyGenerator, set string) (*jose.JSONWebKeySet, error) { - keys, err := g.Generate(uuid.New(), "sig") - if err != nil { - return nil, errors.Wrapf(err, "Could not generate JSON Web Key Set \"%s\".", set) - } - - for i, k := range keys.Keys { - k.Use = "sig" - keys.Keys[i] = k - } - - if err = r.KeyManager().AddKeySet(ctx, set, keys); err != nil { - return nil, errors.Wrapf(err, "Could not persist JSON Web Key Set \"%s\".", set) - } - - return keys, nil } func First(keys []jose.JSONWebKey) *jose.JSONWebKey { @@ -112,17 +88,17 @@ func First(keys []jose.JSONWebKey) *jose.JSONWebKey { return &keys[0] } -func FindKeyByPrefix(set *jose.JSONWebKeySet, prefix string) (key *jose.JSONWebKey, err error) { - keys, err := FindKeysByPrefix(set, prefix) - if err != nil { - return nil, err +func FindPublicKey(set *jose.JSONWebKeySet) (key *jose.JSONWebKey, err error) { + keys := ExcludePrivateKeys(set) + if len(keys.Keys) == 0 { + return nil, errors.New("key not found") } return First(keys.Keys), nil } -func FindPublicKey(set *jose.JSONWebKeySet) (key *jose.JSONWebKey, err error) { - keys := ExcludePrivateKeys(set) +func FindPrivateKey(set *jose.JSONWebKeySet) (key *jose.JSONWebKey, err error) { + keys := ExcludePublicKeys(set) if len(keys.Keys) == 0 { return nil, errors.New("key not found") } @@ -130,7 +106,7 @@ func FindPublicKey(set *jose.JSONWebKeySet) (key *jose.JSONWebKey, err error) { return First(keys.Keys), nil } -func ExcludePrivateKeys(set *jose.JSONWebKeySet) *jose.JSONWebKeySet { +func ExcludePublicKeys(set *jose.JSONWebKeySet) *jose.JSONWebKeySet { keys := new(jose.JSONWebKeySet) for _, k := range set.Keys { @@ -138,27 +114,37 @@ func ExcludePrivateKeys(set *jose.JSONWebKeySet) *jose.JSONWebKeySet { _, ed25519OK := k.Key.(ed25519.PublicKey) _, rsaOK := k.Key.(*rsa.PublicKey) - if ecdsaOk || ed25519OK || rsaOK { + if !ecdsaOk && !ed25519OK && !rsaOK { keys.Keys = append(keys.Keys, k) } } return keys } -func FindKeysByPrefix(set *jose.JSONWebKeySet, prefix string) (*jose.JSONWebKeySet, error) { +func ExcludePrivateKeys(set *jose.JSONWebKeySet) *jose.JSONWebKeySet { keys := new(jose.JSONWebKeySet) for _, k := range set.Keys { - if len(k.KeyID) >= len(prefix)+1 && k.KeyID[:len(prefix)+1] == prefix+":" { + _, ecdsaOk := k.Key.(*ecdsa.PublicKey) + _, ed25519OK := k.Key.(ed25519.PublicKey) + _, rsaOK := k.Key.(*rsa.PublicKey) + + if ecdsaOk || ed25519OK || rsaOK { keys.Keys = append(keys.Keys, k) } } + return keys +} - if len(keys.Keys) == 0 { - return nil, errors.Errorf("Unable to find key with prefix %s in JSON Web Key Set", prefix) - } +func ExcludeOpaquePrivateKeys(set *jose.JSONWebKeySet) *jose.JSONWebKeySet { + keys := new(jose.JSONWebKeySet) - return keys, nil + for _, k := range set.Keys { + if _, opaque := k.Key.(jose.OpaqueSigner); !opaque { + keys.Keys = append(keys.Keys, k) + } + } + return keys } func PEMBlockForKey(key interface{}) (*pem.Block, error) { diff --git a/jwk/helper_test.go b/jwk/helper_test.go index a54813a4dd6..2b875d64990 100644 --- a/jwk/helper_test.go +++ b/jwk/helper_test.go @@ -18,89 +18,245 @@ * @license Apache-2.0 */ -package jwk +package jwk_test import ( + "context" + "crypto/dsa" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "encoding/pem" + "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - jose "gopkg.in/square/go-jose.v2" -) - -func TestFindKeyByPrefix(t *testing.T) { - jwks := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - {KeyID: "public:foo"}, - {KeyID: "private:foo"}, - }} - - key, err := FindKeyByPrefix(jwks, "public") - require.NoError(t, err) - assert.Equal(t, "public:foo", key.KeyID) + "github.com/golang/mock/gomock" + "github.com/pborman/uuid" + "github.com/pkg/errors" - key, err = FindKeyByPrefix(jwks, "private") - require.NoError(t, err) - assert.Equal(t, "private:foo", key.KeyID) + "github.com/ory/hydra/internal" + "github.com/ory/hydra/jwk" + "github.com/ory/hydra/x" - _, err = FindKeyByPrefix(jwks, "asdf") - require.Error(t, err) + "gopkg.in/square/go-jose.v2/cryptosigner" - jwks = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - {KeyID: "public:"}, - {KeyID: "private:"}, - }} + "gopkg.in/square/go-jose.v2" - key, err = FindKeyByPrefix(jwks, "public") - require.NoError(t, err) - assert.Equal(t, "public:", key.KeyID) - - key, err = FindKeyByPrefix(jwks, "private") - require.NoError(t, err) - assert.Equal(t, "private:", key.KeyID) - - _, err = FindKeyByPrefix(jwks, "asdf") - require.Error(t, err) - - jwks = &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{ - {KeyID: ""}, - }} - require.Error(t, err) -} + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) func TestIder(t *testing.T) { - assert.True(t, len(Ider("public", "")) > len("public:")) - assert.Equal(t, "public:foo", Ider("public", "foo")) + assert.True(t, len(jwk.Ider("public", "")) > len("public:")) + assert.Equal(t, "public:foo", jwk.Ider("public", "foo")) } func TestHandlerFindPublicKey(t *testing.T) { - var testRSGenerator = RS256Generator{} - var testECDSAGenerator = ECDSA256Generator{} - var testEdDSAGenerator = EdDSAGenerator{} + var testRSGenerator = jwk.RS256Generator{} + var testECDSAGenerator = jwk.ECDSA256Generator{} + var testEdDSAGenerator = jwk.EdDSAGenerator{} t.Run("Test_Helper/Run_FindPublicKey_With_RSA", func(t *testing.T) { RSIDKS, _ := testRSGenerator.Generate("test-id-1", "sig") - keys, err := FindPublicKey(RSIDKS) + keys, err := jwk.FindPublicKey(RSIDKS) require.NoError(t, err) - assert.Equal(t, keys.KeyID, Ider("public", "test-id-1")) + assert.Equal(t, keys.KeyID, jwk.Ider("public", "test-id-1")) assert.IsType(t, keys.Key, new(rsa.PublicKey)) }) t.Run("Test_Helper/Run_FindPublicKey_With_ECDSA", func(t *testing.T) { ECDSAIDKS, _ := testECDSAGenerator.Generate("test-id-2", "sig") - keys, err := FindPublicKey(ECDSAIDKS) + keys, err := jwk.FindPublicKey(ECDSAIDKS) require.NoError(t, err) - assert.Equal(t, keys.KeyID, Ider("public", "test-id-2")) + assert.Equal(t, keys.KeyID, jwk.Ider("public", "test-id-2")) assert.IsType(t, keys.Key, new(ecdsa.PublicKey)) }) t.Run("Test_Helper/Run_FindPublicKey_With_EdDSA", func(t *testing.T) { EdDSAIDKS, _ := testEdDSAGenerator.Generate("test-id-3", "sig") - keys, err := FindPublicKey(EdDSAIDKS) + keys, err := jwk.FindPublicKey(EdDSAIDKS) require.NoError(t, err) - assert.Equal(t, keys.KeyID, Ider("public", "test-id-3")) + assert.Equal(t, keys.KeyID, jwk.Ider("public", "test-id-3")) assert.IsType(t, keys.Key, ed25519.PublicKey{}) }) + + t.Run("Test_Helper/Run_FindPublicKey_With_KeyNotFound", func(t *testing.T) { + keySet := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}} + _, err := jwk.FindPublicKey(keySet) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "key not found")) + }) +} + +func TestHandlerFindPrivateKey(t *testing.T) { + var testRSGenerator = jwk.RS256Generator{} + var testECDSAGenerator = jwk.ECDSA256Generator{} + var testEdDSAGenerator = jwk.EdDSAGenerator{} + + t.Run("Test_Helper/Run_FindPrivateKey_With_RSA", func(t *testing.T) { + RSIDKS, _ := testRSGenerator.Generate("test-id-1", "sig") + keys, err := jwk.FindPrivateKey(RSIDKS) + require.NoError(t, err) + assert.Equal(t, keys.KeyID, jwk.Ider("private", "test-id-1")) + assert.IsType(t, keys.Key, new(rsa.PrivateKey)) + }) + + t.Run("Test_Helper/Run_FindPrivateKey_With_ECDSA", func(t *testing.T) { + ECDSAIDKS, _ := testECDSAGenerator.Generate("test-id-2", "sig") + keys, err := jwk.FindPrivateKey(ECDSAIDKS) + require.NoError(t, err) + assert.Equal(t, keys.KeyID, jwk.Ider("private", "test-id-2")) + assert.IsType(t, keys.Key, new(ecdsa.PrivateKey)) + }) + + t.Run("Test_Helper/Run_FindPrivateKey_With_EdDSA", func(t *testing.T) { + EdDSAIDKS, _ := testEdDSAGenerator.Generate("test-id-3", "sig") + keys, err := jwk.FindPrivateKey(EdDSAIDKS) + require.NoError(t, err) + assert.Equal(t, keys.KeyID, jwk.Ider("private", "test-id-3")) + assert.IsType(t, keys.Key, ed25519.PrivateKey{}) + }) + + t.Run("Test_Helper/Run_FindPrivateKey_With_KeyNotFound", func(t *testing.T) { + keySet := &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{}} + _, err := jwk.FindPublicKey(keySet) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "key not found")) + }) +} + +func TestPEMBlockForKey(t *testing.T) { + var testRSGenerator = jwk.RS256Generator{} + var testECDSAGenerator = jwk.ECDSA256Generator{} + var testEdDSAGenerator = jwk.EdDSAGenerator{} + + t.Run("Test_Helper/Run_PEMBlockForKey_With_RSA", func(t *testing.T) { + RSIDKS, _ := testRSGenerator.Generate("test-id-1", "sig") + key, err := jwk.FindPrivateKey(RSIDKS) + require.NoError(t, err) + pemBlock, err := jwk.PEMBlockForKey(key.Key) + require.NoError(t, err) + assert.IsType(t, pem.Block{}, *pemBlock) + assert.Equal(t, "RSA PRIVATE KEY", pemBlock.Type) + }) + + t.Run("Test_Helper/Run_PEMBlockForKey_With_ECDSA", func(t *testing.T) { + ECDSAIDKS, _ := testECDSAGenerator.Generate("test-id-2", "sig") + key, err := jwk.FindPrivateKey(ECDSAIDKS) + require.NoError(t, err) + pemBlock, err := jwk.PEMBlockForKey(key.Key) + require.NoError(t, err) + assert.IsType(t, pem.Block{}, *pemBlock) + assert.Equal(t, "EC PRIVATE KEY", pemBlock.Type) + }) + + t.Run("Test_Helper/Run_PEMBlockForKey_With_EdDSA", func(t *testing.T) { + EdDSAIDKS, _ := testEdDSAGenerator.Generate("test-id-3", "sig") + key, err := jwk.FindPrivateKey(EdDSAIDKS) + require.NoError(t, err) + pemBlock, err := jwk.PEMBlockForKey(key.Key) + require.NoError(t, err) + assert.IsType(t, pem.Block{}, *pemBlock) + assert.Equal(t, "PRIVATE KEY", pemBlock.Type) + }) + + t.Run("Test_Helper/Run_PEMBlockForKey_With_InvalidKeyType", func(t *testing.T) { + key := dsa.PrivateKey{} + _, err := jwk.PEMBlockForKey(key) + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "Invalid key type")) + }) +} + +func TestExcludeOpaquePrivateKeys(t *testing.T) { + var testRSGenerator = jwk.RS256Generator{} + + opaqueKeys, err := testRSGenerator.Generate("test-id-1", "sig") + assert.NoError(t, err) + assert.Len(t, opaqueKeys.Keys, 2) + opaqueKeys.Keys[0].Key = cryptosigner.Opaque(opaqueKeys.Keys[0].Key.(*rsa.PrivateKey)) + keys := jwk.ExcludeOpaquePrivateKeys(opaqueKeys) + assert.Len(t, keys.Keys, 1) + assert.IsType(t, new(rsa.PublicKey), keys.Keys[0].Key) +} + +func TestGetOrGenerateKeys(t *testing.T) { + reg := internal.NewMockedRegistry(t) + ctrl := gomock.NewController(t) + keyManager := NewMockManager(ctrl) + defer ctrl.Finish() + + setId := uuid.NewUUID().String() + keyId := uuid.NewUUID().String() + + generator := reg.KeyGenerators()["RS256"] + keySet, _ := generator.Generate(keyId, "sig") + keySetWithoutPublicKey := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{keySet.Keys[0]}, + } + keySetWithoutPrivateKey := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{keySet.Keys[1]}, + } + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError")) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.Nil(t, pubKey) + assert.Nil(t, privKey) + assert.EqualError(t, err, "GetKeySetError") + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, "")) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.Nil(t, pubKey) + assert.Nil(t, privKey) + assert.EqualError(t, err, "GetKeySetError") + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySetError", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError")) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.Nil(t, pubKey) + assert.Nil(t, privKey) + assert.EqualError(t, err, "GetKeySetError") + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPublicKey", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPublicKey, nil) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.NoError(t, err) + assert.Equal(t, privKey, &keySet.Keys[0]) + assert.Equal(t, pubKey, &keySet.Keys[1]) + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.NoError(t, err) + assert.Equal(t, privKey, &keySet.Keys[0]) + assert.Equal(t, pubKey, &keySet.Keys[1]) + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPublicKey", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPublicKey, nil) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.Nil(t, pubKey) + assert.Nil(t, privKey) + assert.EqualError(t, err, "key not found") + }) + + t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GenerateAndPersistKeySet_ContainsMissingPrivateKey", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPublicKey, nil) + keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil) + pubKey, privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256") + assert.Nil(t, pubKey) + assert.Nil(t, privKey) + assert.EqualError(t, err, "key not found") + }) } diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 3c227330db5..7cd9ed9fc09 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -26,6 +26,8 @@ import ( "strings" "sync" + "gopkg.in/square/go-jose.v2" + "github.com/ory/hydra/driver/config" "github.com/pkg/errors" @@ -50,13 +52,13 @@ type RS256JWTStrategy struct { rs func() string publicKey *rsa.PublicKey - privateKey *rsa.PrivateKey + privateKey interface{} publicKeyID string privateKeyID string } -func NewRS256JWTStrategy(r InternalRegistry, rs func() string) (*RS256JWTStrategy, error) { - j := &RS256JWTStrategy{r: r, rs: rs, RS256JWTStrategy: new(jwt.RS256JWTStrategy)} +func NewRS256JWTStrategy(c config.Provider, r InternalRegistry, rs func() string) (*RS256JWTStrategy, error) { + j := &RS256JWTStrategy{c: &c, r: r, rs: rs, RS256JWTStrategy: new(jwt.RS256JWTStrategy)} if err := j.refresh(context.TODO()); err != nil { return nil, err } @@ -114,12 +116,12 @@ func (j *RS256JWTStrategy) refresh(ctx context.Context) error { return err } - public, err := FindKeyByPrefix(keys, "public") + public, err := FindPublicKey(keys) if err != nil { return err } - private, err := FindKeyByPrefix(keys, "private") + private, err := FindPrivateKey(keys) if err != nil { return err } @@ -128,15 +130,6 @@ func (j *RS256JWTStrategy) refresh(ctx context.Context) error { return errors.New("public and private key pair kids do not match") } - if k, ok := private.Key.(*rsa.PrivateKey); !ok { - return errors.New("unable to type assert key to *rsa.PrivateKey") - } else { - j.Lock() - j.privateKey = k - j.RS256JWTStrategy.PrivateKey = k - j.Unlock() - } - if k, ok := public.Key.(*rsa.PublicKey); !ok { return errors.New("unable to type assert key to *rsa.PublicKey") } else { @@ -146,12 +139,25 @@ func (j *RS256JWTStrategy) refresh(ctx context.Context) error { j.Unlock() } - j.RLock() - defer j.RUnlock() - if j.privateKey.PublicKey.E != j.publicKey.E || - j.privateKey.PublicKey.N.String() != j.publicKey.N.String() { - return errors.New("public and private key pair fetched from store does not match") - } + if k, ok := private.Key.(*rsa.PrivateKey); ok { + j.Lock() + j.privateKey = k + j.RS256JWTStrategy.PrivateKey = k + j.Unlock() + j.RLock() + defer j.RUnlock() + if k.PublicKey.E != j.publicKey.E || + k.PublicKey.N.String() != j.publicKey.N.String() { + return errors.New("public and private key pair fetched from store does not match") + } + } else if k, ok := private.Key.(jose.OpaqueSigner); ok { + j.Lock() + j.privateKey = k + j.RS256JWTStrategy.PrivateKey = k + j.Unlock() + } else { + return errors.New("unknown private key type") + } return nil } diff --git a/jwk/jwt_strategy_test.go b/jwk/jwt_strategy_test.go index d6c0161b1cc..4745afed43b 100644 --- a/jwk/jwt_strategy_test.go +++ b/jwk/jwt_strategy_test.go @@ -22,8 +22,16 @@ package jwk_test import ( "context" + "crypto/rsa" + "crypto/x509" + "errors" "testing" + "github.com/golang/mock/gomock" + "github.com/pborman/uuid" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/cryptosigner" + "github.com/ory/hydra/internal" "github.com/stretchr/testify/assert" @@ -38,15 +46,12 @@ import ( func TestRS256JWTStrategy(t *testing.T) { conf := internal.NewConfigurationWithDefaults() reg := internal.NewRegistryMemory(t, conf) - - testGenerator := &RS256Generator{} - m := reg.KeyManager() - ks, err := testGenerator.Generate("foo", "sig") + + _, err := m.GenerateAndPersistKeySet(context.TODO(), "foo-set", "foo", "RS256", "sig") require.NoError(t, err) - require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks)) - s, err := NewRS256JWTStrategy(reg, func() string { + s, err := NewRS256JWTStrategy(*conf, reg, func() string { return "foo-set" }) @@ -59,13 +64,11 @@ func TestRS256JWTStrategy(t *testing.T) { _, err = s.Validate(context.TODO(), a) require.NoError(t, err) - kid, err := s.GetPublicKeyID(context.TODO()) + kidFoo, err := s.GetPublicKeyID(context.TODO()) assert.NoError(t, err) - assert.Equal(t, "public:foo", kid) - ks, err = testGenerator.Generate("bar", "sig") + _, err = m.GenerateAndPersistKeySet(context.TODO(), "foo-set", "bar", "RS256", "sig") require.NoError(t, err) - require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks)) a, b, err = s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) require.NoError(t, err) @@ -75,7 +78,124 @@ func TestRS256JWTStrategy(t *testing.T) { _, err = s.Validate(context.TODO(), a) require.NoError(t, err) - kid, err = s.GetPublicKeyID(context.TODO()) + kidBar, err := s.GetPublicKeyID(context.TODO()) assert.NoError(t, err) - assert.Equal(t, "public:bar", kid) + + if conf.HsmEnabled() { + assert.Equal(t, "foo", kidFoo) + assert.Equal(t, "bar", kidBar) + } else { + assert.Equal(t, "public:foo", kidFoo) + assert.Equal(t, "public:bar", kidBar) + } +} + +func TestRS256JWTStrategy_Refresh(t *testing.T) { + conf := internal.NewConfigurationWithDefaults() + ctrl := gomock.NewController(t) + keyManager := NewMockManager(ctrl) + reg := NewMockInternalRegistry(ctrl) + defer ctrl.Finish() + + reg.EXPECT().KeyManager().Return(keyManager).AnyTimes() + + setId := uuid.NewUUID().String() + keyId := uuid.NewUUID().String() + + rsaGenerator := &RS256Generator{KeyLength: 1024} + rsaKeySet, err := rsaGenerator.Generate(keyId, "sig") + require.NoError(t, err) + edsaGenerator := &ECDSA256Generator{} + edsaKeySet, err := edsaGenerator.Generate(keyId, "sig") + require.NoError(t, err) + + t.Run("With_RsaKeyPair", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(rsaKeySet, nil) + strategy, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.NoError(t, err) + require.IsType(t, new(rsa.PrivateKey), strategy.RS256JWTStrategy.PrivateKey) + }) + + t.Run("With_OpaqueKeyPair", func(t *testing.T) { + opaquePrivateKey := cryptosigner.Opaque(rsaKeySet.Keys[0].Key.(*rsa.PrivateKey)) + keySetWithOpaquePrivateKey := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + Algorithm: "RS256", + Use: "sig", + Key: opaquePrivateKey, + KeyID: keyId, + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }, rsaKeySet.Keys[1]}, + } + + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithOpaquePrivateKey, nil) + strategy, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.NoError(t, err) + require.IsType(t, opaquePrivateKey, strategy.RS256JWTStrategy.PrivateKey) + }) + + t.Run("With_GetKeySetError", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError")) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "GetKeySetError") + }) + + t.Run("With_FindPublicKeyError", func(t *testing.T) { + keySetWithoutPublicKey := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{rsaKeySet.Keys[0]}, + } + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPublicKey, nil) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "key not found") + }) + + t.Run("With_FindPrivateKeyError", func(t *testing.T) { + keySetWithoutPrivateKey := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{rsaKeySet.Keys[1]}, + } + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "key not found") + }) + + t.Run("With_PublicKeyTypeError", func(t *testing.T) { + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(edsaKeySet, nil) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "unable to type assert key to *rsa.PublicKey") + }) + + t.Run("With_PrivateKeyTypeError", func(t *testing.T) { + keyInvalidPrivateKeyType := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{edsaKeySet.Keys[0], rsaKeySet.Keys[1]}, + } + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keyInvalidPrivateKeyType, nil) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "unknown private key type") + }) + + t.Run("With_KeyPairIdsNotMatchError", func(t *testing.T) { + rsaKeySet.Keys[0].KeyID = uuid.NewUUID().String() + keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(rsaKeySet, nil) + _, err := NewRS256JWTStrategy(*conf, reg, func() string { + return setId + }) + require.EqualError(t, err, "public and private key pair kids do not match") + rsaKeySet.Keys[0].KeyID = rsaKeySet.Keys[1].KeyID + }) } diff --git a/jwk/manager.go b/jwk/manager.go index e2fe33d5768..e2c5cbfae53 100644 --- a/jwk/manager.go +++ b/jwk/manager.go @@ -22,17 +22,38 @@ package jwk import ( "context" + "net/http" "time" + "github.com/ory/fosite" + jose "gopkg.in/square/go-jose.v2" ) +var ErrUnsupportedKeyAlgorithm = &fosite.RFC6749Error{ + CodeField: http.StatusBadRequest, + ErrorField: http.StatusText(http.StatusBadRequest), + DescriptionField: "Unsupported key algorithm", +} + +var ErrUnsupportedEllipticCurve = &fosite.RFC6749Error{ + CodeField: http.StatusBadRequest, + ErrorField: http.StatusText(http.StatusBadRequest), + DescriptionField: "Unsupported elliptic curve", +} + type ( Manager interface { + GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) + AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error + UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error + + UpdateKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error + GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) diff --git a/jwk/manager_mock_test.go b/jwk/manager_mock_test.go new file mode 100644 index 00000000000..b91641dedee --- /dev/null +++ b/jwk/manager_mock_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: jwk/manager.go + +// Package mock_jwk is a generated GoMock package. +package jwk_test + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + jose "gopkg.in/square/go-jose.v2" +) + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// AddKey mocks base method. +func (m *MockManager) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddKey", ctx, set, key) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddKey indicates an expected call of AddKey. +func (mr *MockManagerMockRecorder) AddKey(ctx, set, key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKey", reflect.TypeOf((*MockManager)(nil).AddKey), ctx, set, key) +} + +// AddKeySet mocks base method. +func (m *MockManager) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddKeySet", ctx, set, keys) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddKeySet indicates an expected call of AddKeySet. +func (mr *MockManagerMockRecorder) AddKeySet(ctx, set, keys interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddKeySet", reflect.TypeOf((*MockManager)(nil).AddKeySet), ctx, set, keys) +} + +// DeleteKey mocks base method. +func (m *MockManager) DeleteKey(ctx context.Context, set, kid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteKey", ctx, set, kid) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteKey indicates an expected call of DeleteKey. +func (mr *MockManagerMockRecorder) DeleteKey(ctx, set, kid interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKey", reflect.TypeOf((*MockManager)(nil).DeleteKey), ctx, set, kid) +} + +// DeleteKeySet mocks base method. +func (m *MockManager) DeleteKeySet(ctx context.Context, set string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteKeySet", ctx, set) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteKeySet indicates an expected call of DeleteKeySet. +func (mr *MockManagerMockRecorder) DeleteKeySet(ctx, set interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteKeySet", reflect.TypeOf((*MockManager)(nil).DeleteKeySet), ctx, set) +} + +// GenerateAndPersistKeySet mocks base method. +func (m *MockManager) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateAndPersistKeySet", ctx, set, kid, alg, use) + ret0, _ := ret[0].(*jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateAndPersistKeySet indicates an expected call of GenerateAndPersistKeySet. +func (mr *MockManagerMockRecorder) GenerateAndPersistKeySet(ctx, set, kid, alg, use interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateAndPersistKeySet", reflect.TypeOf((*MockManager)(nil).GenerateAndPersistKeySet), ctx, set, kid, alg, use) +} + +// GetKey mocks base method. +func (m *MockManager) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKey", ctx, set, kid) + ret0, _ := ret[0].(*jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKey indicates an expected call of GetKey. +func (mr *MockManagerMockRecorder) GetKey(ctx, set, kid interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKey", reflect.TypeOf((*MockManager)(nil).GetKey), ctx, set, kid) +} + +// GetKeySet mocks base method. +func (m *MockManager) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKeySet", ctx, set) + ret0, _ := ret[0].(*jose.JSONWebKeySet) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKeySet indicates an expected call of GetKeySet. +func (mr *MockManagerMockRecorder) GetKeySet(ctx, set interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKeySet", reflect.TypeOf((*MockManager)(nil).GetKeySet), ctx, set) +} + +// UpdateKey mocks base method. +func (m *MockManager) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateKey", ctx, set, key) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateKey indicates an expected call of UpdateKey. +func (mr *MockManagerMockRecorder) UpdateKey(ctx, set, key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateKey", reflect.TypeOf((*MockManager)(nil).UpdateKey), ctx, set, key) +} + +// UpdateKeySet mocks base method. +func (m *MockManager) UpdateKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateKeySet", ctx, set, keys) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateKeySet indicates an expected call of UpdateKeySet. +func (mr *MockManagerMockRecorder) UpdateKeySet(ctx, set, keys interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateKeySet", reflect.TypeOf((*MockManager)(nil).UpdateKeySet), ctx, set, keys) +} diff --git a/jwk/manager_strategy.go b/jwk/manager_strategy.go new file mode 100644 index 00000000000..2e7160181c7 --- /dev/null +++ b/jwk/manager_strategy.go @@ -0,0 +1,86 @@ +package jwk + +import ( + "context" + + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + + "github.com/ory/hydra/x" +) + +type ManagerStrategy struct { + hardwareKeyManager Manager + softwareKeyManager Manager +} + +func NewManagerStrategy(hardwareKeyManager Manager, softwareKeyManager Manager) *ManagerStrategy { + return &ManagerStrategy{ + hardwareKeyManager: hardwareKeyManager, + softwareKeyManager: softwareKeyManager, + } +} + +func (m ManagerStrategy) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + return m.hardwareKeyManager.GenerateAndPersistKeySet(ctx, set, kid, alg, use) +} + +func (m ManagerStrategy) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + return m.softwareKeyManager.AddKey(ctx, set, key) +} + +func (m ManagerStrategy) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + return m.softwareKeyManager.AddKeySet(ctx, set, keys) +} + +func (m ManagerStrategy) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + return m.softwareKeyManager.UpdateKey(ctx, set, key) +} + +func (m ManagerStrategy) UpdateKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + return m.softwareKeyManager.UpdateKeySet(ctx, set, keys) +} + +func (m ManagerStrategy) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { + keySet, err := m.hardwareKeyManager.GetKey(ctx, set, kid) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keySet != nil { + return keySet, nil + } else { + return m.softwareKeyManager.GetKey(ctx, set, kid) + } +} + +func (m ManagerStrategy) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { + keySet, err := m.hardwareKeyManager.GetKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keySet != nil { + return keySet, nil + } else { + return m.softwareKeyManager.GetKeySet(ctx, set) + } +} + +func (m ManagerStrategy) DeleteKey(ctx context.Context, set, kid string) error { + err := m.hardwareKeyManager.DeleteKey(ctx, set, kid) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return err + } else if errors.Is(err, x.ErrNotFound) { + return m.softwareKeyManager.DeleteKey(ctx, set, kid) + } else { + return nil + } +} + +func (m ManagerStrategy) DeleteKeySet(ctx context.Context, set string) error { + err := m.hardwareKeyManager.DeleteKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return err + } else if errors.Is(err, x.ErrNotFound) { + return m.softwareKeyManager.DeleteKeySet(ctx, set) + } else { + return nil + } +} diff --git a/jwk/manager_strategy_test.go b/jwk/manager_strategy_test.go new file mode 100644 index 00000000000..ce255e9fab3 --- /dev/null +++ b/jwk/manager_strategy_test.go @@ -0,0 +1,205 @@ +package jwk_test + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + + "github.com/ory/hydra/jwk" + "github.com/ory/hydra/x" +) + +func TestKeyManagerStrategy(t *testing.T) { + ctrl := gomock.NewController(t) + softwareKeyManager := NewMockManager(ctrl) + hardwareKeyManager := NewMockManager(ctrl) + keyManager := jwk.NewManagerStrategy(hardwareKeyManager, softwareKeyManager) + defer ctrl.Finish() + hwKeySet := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + KeyID: "hwKeyID", + }}, + } + swKeySet := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + KeyID: "swKeyID", + }}, + } + + t.Run("GenerateAndPersistKeySet_WithResult", func(t *testing.T) { + hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GenerateAndPersistKeySet_WithError", func(t *testing.T) { + hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("AddKey", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.AddKey(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("AddKey_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.AddKey(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("AddKeySet", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.AddKeySet(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("AddKeySet_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.AddKeySet(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("UpdateKey", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.UpdateKey(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("UpdateKey_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.UpdateKey(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("UpdateKeySet", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.UpdateKeySet(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("UpdateKeySet_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.UpdateKeySet(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("GetKey_WithResultFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GetKey_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(swKeySet, nil) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.NoError(t, err) + assert.Equal(t, swKeySet, resultKeySet) + }) + + t.Run("GetKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.Error(t, err, "Not Found") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKeySet_WithResultFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GetKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(swKeySet, nil) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.NoError(t, err) + assert.Equal(t, swKeySet, resultKeySet) + }) + + t.Run("GetKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.Error(t, err, "Not Found") + assert.Nil(t, resultKeySet) + }) + + t.Run("DeleteKey_FromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.NoError(t, err) + }) + + t.Run("DeleteKey_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.New("test")) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.Error(t, err, "test") + }) + + t.Run("DeleteKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.NoError(t, err) + }) + + t.Run("DeleteKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.Error(t, err, "Not Found") + }) + + t.Run("DeleteKeySet_FromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil) + err := keyManager.DeleteKeySet(nil, "set1") + assert.NoError(t, err) + }) + + t.Run("DeleteKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.New("test")) + err := keyManager.DeleteKeySet(nil, "set1") + assert.Error(t, err, "test") + }) + + t.Run("DeleteKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil) + err := keyManager.DeleteKeySet(nil, "set1") + assert.NoError(t, err) + }) + + t.Run("DeleteKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + err := keyManager.DeleteKeySet(nil, "set1") + assert.Error(t, err, "Not Found") + }) +} diff --git a/jwk/manager_test_helpers.go b/jwk/manager_test_helpers.go index 3d50a57321e..bd12b738fd4 100644 --- a/jwk/manager_test_helpers.go +++ b/jwk/manager_test_helpers.go @@ -44,17 +44,21 @@ func RandomBytes(n int) ([]byte, error) { func canonicalizeThumbprints(js []jose.JSONWebKey) []jose.JSONWebKey { for k, v := range js { - if len(v.CertificateThumbprintSHA1) == 0 { - v.CertificateThumbprintSHA1 = nil - } - if len(v.CertificateThumbprintSHA256) == 0 { - v.CertificateThumbprintSHA256 = nil - } - js[k] = v + js[k] = canonicalizeKeyThumbprints(&v) } return js } +func canonicalizeKeyThumbprints(v *jose.JSONWebKey) jose.JSONWebKey { + if len(v.CertificateThumbprintSHA1) == 0 { + v.CertificateThumbprintSHA1 = nil + } + if len(v.CertificateThumbprintSHA256) == 0 { + v.CertificateThumbprintSHA256 = nil + } + return *v +} + func TestHelperManagerKey(m Manager, algo string, keys *jose.JSONWebKeySet, suffix string) func(t *testing.T) { pub := canonicalizeThumbprints(keys.Key("public:" + suffix)) priv := canonicalizeThumbprints(keys.Key("private:" + suffix)) @@ -85,11 +89,22 @@ func TestHelperManagerKey(m Manager, algo string, keys *jose.JSONWebKeySet, suff time.Sleep(time.Second * 2) First(pub).KeyID = "new-key-id:" + suffix + First(pub).Use = "sig" err = m.AddKey(context.TODO(), algo+"faz", First(pub)) require.NoError(t, err) - _, err = m.GetKey(context.TODO(), algo+"faz", "new-key-id:"+suffix) + got, err = m.GetKey(context.TODO(), algo+"faz", "new-key-id:"+suffix) + require.NoError(t, err) + newKey := First(got.Keys) + assert.EqualValues(t, "sig", newKey.Use) + + newKey.Use = "enc" + err = m.UpdateKey(context.TODO(), algo+"faz", newKey) require.NoError(t, err) + updated, err := m.GetKey(context.TODO(), algo+"faz", "new-key-id:"+suffix) + require.NoError(t, err) + updatedKey := First(updated.Keys) + assert.EqualValues(t, "enc", updatedKey.Use) keys, err = m.GetKeySet(context.TODO(), algo+"faz") require.NoError(t, err) @@ -121,6 +136,16 @@ func TestHelperManagerKeySet(m Manager, algo string, keys *jose.JSONWebKeySet, s assert.Equal(t, canonicalizeThumbprints(keys.Key("public:"+suffix)), canonicalizeThumbprints(got.Key("public:"+suffix))) assert.Equal(t, canonicalizeThumbprints(keys.Key("private:"+suffix)), canonicalizeThumbprints(got.Key("private:"+suffix))) + for i, _ := range got.Keys { + got.Keys[i].Use = "enc" + } + err = m.UpdateKeySet(context.TODO(), algo+"bar", got) + require.NoError(t, err) + updated, err := m.GetKeySet(context.TODO(), algo+"bar") + require.NoError(t, err) + assert.EqualValues(t, "enc", First(updated.Key("public:"+suffix)).Use) + assert.EqualValues(t, "enc", First(updated.Key("private:"+suffix)).Use) + err = m.DeleteKeySet(context.TODO(), algo+"bar") require.NoError(t, err) @@ -128,3 +153,33 @@ func TestHelperManagerKeySet(m Manager, algo string, keys *jose.JSONWebKeySet, s require.Error(t, err) } } + +func TestHelperManagerGenerateAndPersistKeySet(m Manager, alg string) func(t *testing.T) { + return func(t *testing.T) { + _, err := m.GetKeySet(context.TODO(), "foo") + require.Error(t, err) + + keys, err := m.GenerateAndPersistKeySet(context.TODO(), "foo", "bar", alg, "sig") + require.NoError(t, err) + genPub, err := FindPublicKey(keys) + require.NoError(t, err) + genPriv, err := FindPrivateKey(keys) + require.NoError(t, err) + + got, err := m.GetKeySet(context.TODO(), "foo") + require.NoError(t, err) + gotPub, err := FindPublicKey(got) + require.NoError(t, err) + gotPriv, err := FindPrivateKey(got) + require.NoError(t, err) + + assert.Equal(t, canonicalizeKeyThumbprints(genPub), canonicalizeKeyThumbprints(gotPub)) + assert.Equal(t, canonicalizeKeyThumbprints(genPriv), canonicalizeKeyThumbprints(gotPriv)) + + err = m.DeleteKeySet(context.TODO(), "foo") + require.NoError(t, err) + + _, err = m.GetKeySet(context.TODO(), "foo") + require.Error(t, err) + } +} diff --git a/jwk/registry.go b/jwk/registry.go index d659d561356..202506a1eb2 100644 --- a/jwk/registry.go +++ b/jwk/registry.go @@ -12,6 +12,7 @@ type InternalRegistry interface { type Registry interface { KeyManager() Manager + SoftwareKeyManager() Manager KeyGenerators() map[string]KeyGenerator KeyCipher() *AEAD } diff --git a/jwk/registry_mock_test.go b/jwk/registry_mock_test.go new file mode 100644 index 00000000000..0576c1e9545 --- /dev/null +++ b/jwk/registry_mock_test.go @@ -0,0 +1,217 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: jwk/registry.go + +// Package mock_jwk is a generated GoMock package. +package jwk_test + +import ( + reflect "reflect" + + "github.com/ory/hydra/jwk" + + gomock "github.com/golang/mock/gomock" + + herodot "github.com/ory/herodot" + + logrusx "github.com/ory/x/logrusx" +) + +// MockInternalRegistry is a mock of InternalRegistry interface. +type MockInternalRegistry struct { + ctrl *gomock.Controller + recorder *MockInternalRegistryMockRecorder +} + +// MockInternalRegistryMockRecorder is the mock recorder for MockInternalRegistry. +type MockInternalRegistryMockRecorder struct { + mock *MockInternalRegistry +} + +// NewMockInternalRegistry creates a new mock instance. +func NewMockInternalRegistry(ctrl *gomock.Controller) *MockInternalRegistry { + mock := &MockInternalRegistry{ctrl: ctrl} + mock.recorder = &MockInternalRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInternalRegistry) EXPECT() *MockInternalRegistryMockRecorder { + return m.recorder +} + +// AuditLogger mocks base method. +func (m *MockInternalRegistry) AuditLogger() *logrusx.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AuditLogger") + ret0, _ := ret[0].(*logrusx.Logger) + return ret0 +} + +// AuditLogger indicates an expected call of AuditLogger. +func (mr *MockInternalRegistryMockRecorder) AuditLogger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AuditLogger", reflect.TypeOf((*MockInternalRegistry)(nil).AuditLogger)) +} + +// KeyCipher mocks base method. +func (m *MockInternalRegistry) KeyCipher() *jwk.AEAD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyCipher") + ret0, _ := ret[0].(*jwk.AEAD) + return ret0 +} + +// KeyCipher indicates an expected call of KeyCipher. +func (mr *MockInternalRegistryMockRecorder) KeyCipher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyCipher", reflect.TypeOf((*MockInternalRegistry)(nil).KeyCipher)) +} + +// KeyGenerators mocks base method. +func (m *MockInternalRegistry) KeyGenerators() map[string]jwk.KeyGenerator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyGenerators") + ret0, _ := ret[0].(map[string]jwk.KeyGenerator) + return ret0 +} + +// KeyGenerators indicates an expected call of KeyGenerators. +func (mr *MockInternalRegistryMockRecorder) KeyGenerators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyGenerators", reflect.TypeOf((*MockInternalRegistry)(nil).KeyGenerators)) +} + +// KeyManager mocks base method. +func (m *MockInternalRegistry) KeyManager() jwk.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyManager") + ret0, _ := ret[0].(jwk.Manager) + return ret0 +} + +// KeyManager indicates an expected call of KeyManager. +func (mr *MockInternalRegistryMockRecorder) KeyManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyManager", reflect.TypeOf((*MockInternalRegistry)(nil).KeyManager)) +} + +// Logger mocks base method. +func (m *MockInternalRegistry) Logger() *logrusx.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger") + ret0, _ := ret[0].(*logrusx.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockInternalRegistryMockRecorder) Logger() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockInternalRegistry)(nil).Logger)) +} + +// SoftwareKeyManager mocks base method. +func (m *MockInternalRegistry) SoftwareKeyManager() jwk.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SoftwareKeyManager") + ret0, _ := ret[0].(jwk.Manager) + return ret0 +} + +// SoftwareKeyManager indicates an expected call of SoftwareKeyManager. +func (mr *MockInternalRegistryMockRecorder) SoftwareKeyManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftwareKeyManager", reflect.TypeOf((*MockInternalRegistry)(nil).SoftwareKeyManager)) +} + +// Writer mocks base method. +func (m *MockInternalRegistry) Writer() herodot.Writer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Writer") + ret0, _ := ret[0].(herodot.Writer) + return ret0 +} + +// Writer indicates an expected call of Writer. +func (mr *MockInternalRegistryMockRecorder) Writer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Writer", reflect.TypeOf((*MockInternalRegistry)(nil).Writer)) +} + +// MockRegistry is a mock of Registry interface. +type MockRegistry struct { + ctrl *gomock.Controller + recorder *MockRegistryMockRecorder +} + +// MockRegistryMockRecorder is the mock recorder for MockRegistry. +type MockRegistryMockRecorder struct { + mock *MockRegistry +} + +// NewMockRegistry creates a new mock instance. +func NewMockRegistry(ctrl *gomock.Controller) *MockRegistry { + mock := &MockRegistry{ctrl: ctrl} + mock.recorder = &MockRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRegistry) EXPECT() *MockRegistryMockRecorder { + return m.recorder +} + +// KeyCipher mocks base method. +func (m *MockRegistry) KeyCipher() *jwk.AEAD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyCipher") + ret0, _ := ret[0].(*jwk.AEAD) + return ret0 +} + +// KeyCipher indicates an expected call of KeyCipher. +func (mr *MockRegistryMockRecorder) KeyCipher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyCipher", reflect.TypeOf((*MockRegistry)(nil).KeyCipher)) +} + +// KeyGenerators mocks base method. +func (m *MockRegistry) KeyGenerators() map[string]jwk.KeyGenerator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyGenerators") + ret0, _ := ret[0].(map[string]jwk.KeyGenerator) + return ret0 +} + +// KeyGenerators indicates an expected call of KeyGenerators. +func (mr *MockRegistryMockRecorder) KeyGenerators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyGenerators", reflect.TypeOf((*MockRegistry)(nil).KeyGenerators)) +} + +// KeyManager mocks base method. +func (m *MockRegistry) KeyManager() jwk.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyManager") + ret0, _ := ret[0].(jwk.Manager) + return ret0 +} + +// KeyManager indicates an expected call of KeyManager. +func (mr *MockRegistryMockRecorder) KeyManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyManager", reflect.TypeOf((*MockRegistry)(nil).KeyManager)) +} + +// SoftwareKeyManager mocks base method. +func (m *MockRegistry) SoftwareKeyManager() jwk.Manager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SoftwareKeyManager") + ret0, _ := ret[0].(jwk.Manager) + return ret0 +} + +// SoftwareKeyManager indicates an expected call of SoftwareKeyManager. +func (mr *MockRegistryMockRecorder) SoftwareKeyManager() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftwareKeyManager", reflect.TypeOf((*MockRegistry)(nil).SoftwareKeyManager)) +} diff --git a/jwk/sdk_test.go b/jwk/sdk_test.go index 30d73df4fee..c9779ec7abf 100644 --- a/jwk/sdk_test.go +++ b/jwk/sdk_test.go @@ -53,49 +53,67 @@ func TestJWKSDK(t *testing.T) { server := httptest.NewServer(router) sdk := client.NewHTTPClientWithConfig(nil, &client.TransportConfig{Schemes: []string{"http"}, Host: urlx.ParseOrPanic(server.URL).Host}) + expectedPublicKid := "public:key-bar" + if conf.HsmEnabled() { + expectedPublicKid = "key-bar" + } + t.Run("JSON Web Key", func(t *testing.T) { t.Run("CreateJwkSetKey", func(t *testing.T) { // Create a key called set-foo resultKeys, err := sdk.Admin.CreateJSONWebKeySet(admin.NewCreateJSONWebKeySetParams().WithSet("set-foo").WithBody(&models.JSONWebKeySetGeneratorRequest{ - Alg: pointerx.String("HS256"), + Alg: pointerx.String("RS256"), Kid: pointerx.String("key-bar"), Use: pointerx.String("sig"), })) require.NoError(t, err) - require.Len(t, resultKeys.Payload.Keys, 1) - assert.Equal(t, "key-bar", *resultKeys.Payload.Keys[0].Kid) - assert.Equal(t, "HS256", *resultKeys.Payload.Keys[0].Alg) - assert.Equal(t, "sig", *resultKeys.Payload.Keys[0].Use) + if conf.HsmEnabled() { + require.Len(t, resultKeys.Payload.Keys, 1) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + assert.Equal(t, "sig", *resultKeys.Payload.Keys[0].Use) + } else { + require.Len(t, resultKeys.Payload.Keys, 2) + assert.Equal(t, "private:key-bar", *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + assert.Equal(t, "sig", *resultKeys.Payload.Keys[0].Use) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[1].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[1].Alg) + assert.Equal(t, "sig", *resultKeys.Payload.Keys[1].Use) + } }) var resultKeys *models.JSONWebKeySet t.Run("GetJwkSetKey after create", func(t *testing.T) { - result, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithKid("key-bar").WithSet("set-foo")) + result, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithKid(expectedPublicKid).WithSet("set-foo")) require.NoError(t, err) require.Len(t, result.Payload.Keys, 1) - require.Equal(t, "key-bar", *result.Payload.Keys[0].Kid) - require.Equal(t, "HS256", *result.Payload.Keys[0].Alg) + require.Equal(t, expectedPublicKid, *result.Payload.Keys[0].Kid) + require.Equal(t, "RS256", *result.Payload.Keys[0].Alg) resultKeys = result.Payload }) t.Run("UpdateJwkSetKey", func(t *testing.T) { + if conf.HsmEnabled() { + t.Skip("Skipping test. Keys cannot be updated when Hardware Security Module is enabled") + } require.Len(t, resultKeys.Keys, 1) - resultKeys.Keys[0].Alg = pointerx.String("RS256") + resultKeys.Keys[0].Alg = pointerx.String("ES256") - resultKey, err := sdk.Admin.UpdateJSONWebKey(admin.NewUpdateJSONWebKeyParams().WithKid("key-bar").WithSet("set-foo").WithBody(resultKeys.Keys[0])) + resultKey, err := sdk.Admin.UpdateJSONWebKey(admin.NewUpdateJSONWebKeyParams().WithKid(expectedPublicKid).WithSet("set-foo").WithBody(resultKeys.Keys[0])) require.NoError(t, err) - assert.Equal(t, "key-bar", *resultKey.Payload.Kid) - assert.Equal(t, "RS256", *resultKey.Payload.Alg) + assert.Equal(t, expectedPublicKid, *resultKey.Payload.Kid) + assert.Equal(t, "ES256", *resultKey.Payload.Alg) }) t.Run("DeleteJwkSetKey after delete", func(t *testing.T) { - _, err := sdk.Admin.DeleteJSONWebKey(admin.NewDeleteJSONWebKeyParams().WithKid("key-bar").WithSet("set-foo")) + _, err := sdk.Admin.DeleteJSONWebKey(admin.NewDeleteJSONWebKeyParams().WithKid(expectedPublicKid).WithSet("set-foo")) require.NoError(t, err) }) t.Run("GetJwkSetKey after delete", func(t *testing.T) { - _, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithKid("key-bar").WithSet("set-foo")) + _, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithKid(expectedPublicKid).WithSet("set-foo")) require.Error(t, err) }) @@ -104,33 +122,54 @@ func TestJWKSDK(t *testing.T) { t.Run("JWK Set", func(t *testing.T) { t.Run("CreateJwkSetKey", func(t *testing.T) { resultKeys, err := sdk.Admin.CreateJSONWebKeySet(admin.NewCreateJSONWebKeySetParams().WithSet("set-foo2").WithBody(&models.JSONWebKeySetGeneratorRequest{ - Alg: pointerx.String("HS256"), + Alg: pointerx.String("RS256"), Kid: pointerx.String("key-bar"), })) require.NoError(t, err) - - require.Len(t, resultKeys.Payload.Keys, 1) - assert.Equal(t, "key-bar", *resultKeys.Payload.Keys[0].Kid) - assert.Equal(t, "HS256", *resultKeys.Payload.Keys[0].Alg) + if conf.HsmEnabled() { + require.Len(t, resultKeys.Payload.Keys, 1) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + } else { + require.Len(t, resultKeys.Payload.Keys, 2) + assert.Equal(t, "private:key-bar", *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[1].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[1].Alg) + } }) resultKeys, err := sdk.Admin.GetJSONWebKeySet(admin.NewGetJSONWebKeySetParams().WithSet("set-foo2")) t.Run("GetJwkSet after create", func(t *testing.T) { require.NoError(t, err) - require.Len(t, resultKeys.Payload.Keys, 1) - assert.Equal(t, "key-bar", *resultKeys.Payload.Keys[0].Kid) - assert.Equal(t, "HS256", *resultKeys.Payload.Keys[0].Alg) + if conf.HsmEnabled() { + require.Len(t, resultKeys.Payload.Keys, 1) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + } else { + require.Len(t, resultKeys.Payload.Keys, 2) + assert.Equal(t, expectedPublicKid, *resultKeys.Payload.Keys[0].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[0].Alg) + assert.Equal(t, "private:key-bar", *resultKeys.Payload.Keys[1].Kid) + assert.Equal(t, "RS256", *resultKeys.Payload.Keys[1].Alg) + } }) t.Run("UpdateJwkSet", func(t *testing.T) { - require.Len(t, resultKeys.Payload.Keys, 1) - resultKeys.Payload.Keys[0].Alg = pointerx.String("RS256") + if conf.HsmEnabled() { + t.Skip("Skipping test. Keys cannot be updated when Hardware Security Module is enabled") + } + require.Len(t, resultKeys.Payload.Keys, 2) + resultKeys.Payload.Keys[0].Alg = pointerx.String("ES256") + resultKeys.Payload.Keys[1].Alg = pointerx.String("ES256") result, err := sdk.Admin.UpdateJSONWebKeySet(admin.NewUpdateJSONWebKeySetParams().WithSet("set-foo2").WithBody(resultKeys.Payload)) require.NoError(t, err) - require.Len(t, result.Payload.Keys, 1) - assert.Equal(t, "key-bar", *result.Payload.Keys[0].Kid) - assert.Equal(t, "RS256", *result.Payload.Keys[0].Alg) + require.Len(t, result.Payload.Keys, 2) + assert.Equal(t, expectedPublicKid, *result.Payload.Keys[0].Kid) + assert.Equal(t, "ES256", *result.Payload.Keys[0].Alg) + assert.Equal(t, "private:key-bar", *result.Payload.Keys[1].Kid) + assert.Equal(t, "ES256", *result.Payload.Keys[1].Alg) }) t.Run("DeleteJwkSet", func(t *testing.T) { @@ -144,7 +183,7 @@ func TestJWKSDK(t *testing.T) { }) t.Run("GetJwkSetKey after delete", func(t *testing.T) { - _, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithSet("set-foo2").WithKid("key-bar")) + _, err := sdk.Admin.GetJSONWebKey(admin.NewGetJSONWebKeyParams().WithSet("set-foo2").WithKid(expectedPublicKid)) require.Error(t, err) }) }) diff --git a/oauth2/handler_test.go b/oauth2/handler_test.go index b1058cd2d86..f047719b76a 100644 --- a/oauth2/handler_test.go +++ b/oauth2/handler_test.go @@ -370,7 +370,7 @@ func TestUserinfo(t *testing.T) { keys, err := reg.KeyManager().GetKeySet(context.Background(), x.OpenIDConnectKeyName) require.NoError(t, err) t.Logf("%+v", keys) - key, err := jwk.FindKeyByPrefix(keys, "public") + key, err := jwk.FindPublicKey(keys) return jwk.MustRSAPublic(key), nil }) require.NoError(t, err) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index dda45afc8a0..48b891f6395 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -41,6 +41,7 @@ type ( Dependencies interface { ClientHasher() fosite.Hasher KeyCipher() *jwk.AEAD + KeyGenerators() map[string]jwk.KeyGenerator x.RegistryLogger x.TracingProvider } diff --git a/persistence/sql/persister_jwk.go b/persistence/sql/persister_jwk.go index d88f45bf809..2f787249843 100644 --- a/persistence/sql/persister_jwk.go +++ b/persistence/sql/persister_jwk.go @@ -18,6 +18,25 @@ import ( var _ jwk.Manager = &Persister{} +func (p *Persister) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + generator, found := p.r.KeyGenerators()[alg] + if !found { + return nil, errorsx.WithStack(jwk.ErrUnsupportedKeyAlgorithm) + } + + keys, err := generator.Generate(kid, use) + if err != nil { + return nil, err + } + + err = p.AddKeySet(ctx, set, keys) + if err != nil { + return nil, err + } + + return keys, nil +} + func (p *Persister) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { out, err := json.Marshal(key) if err != nil { @@ -63,6 +82,30 @@ func (p *Persister) AddKeySet(ctx context.Context, set string, keys *jose.JSONWe }) } +func (p *Persister) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if err := p.DeleteKey(ctx, set, key.KeyID); err != nil { + return err + } + if err := p.AddKey(ctx, set, key); err != nil { + return err + } + return nil + }) +} + +func (p *Persister) UpdateKeySet(ctx context.Context, set string, keySet *jose.JSONWebKeySet) error { + return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + if err := p.DeleteKeySet(ctx, set); err != nil { + return err + } + if err := p.AddKeySet(ctx, set, keySet); err != nil { + return err + } + return nil + }) +} + func (p *Persister) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { var j jwk.SQLData if err := p.Connection(ctx). diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 4ffa1f516e3..9cf4f3865cf 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -1,12 +1,16 @@ package sql_test import ( + "context" "testing" + "github.com/pkg/errors" + "github.com/pborman/uuid" "github.com/ory/hydra/oauth2/trust" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ory/hydra/internal/testhelpers" @@ -45,23 +49,43 @@ func TestManagers(t *testing.T) { t.Run("package=consent/janitor="+k, testhelpers.JanitorTests(m.Config(), m.ConsentManager(), m.ClientManager(), m.OAuth2Storage())) t.Run("package=jwk/manager="+k, func(t *testing.T) { - testGenerators := new(driver.RegistryBase).KeyGenerators() - for algo, testGenerator := range testGenerators { - t.Run("TestManagerKey", func(t *testing.T) { - ks, err := testGenerator.Generate("TestManagerKey", "sig") - require.NoError(t, err) - - jwk.TestHelperManagerKey(m.KeyManager(), algo, ks, uuid.New()) - }) - - t.Run("TestManagerKeySet", func(t *testing.T) { - ks, err := testGenerator.Generate("TestManagerKeySet", "sig") - require.NoError(t, err) - ks.Key("private") - - jwk.TestHelperManagerKeySet(m.KeyManager(), algo, ks, uuid.New()) + keyGenerators := new(driver.RegistryBase).KeyGenerators() + assert.Equalf(t, 6, len(keyGenerators), "Test for key generator is not implemented") + + for _, tc := range []struct { + keyGenerator jwk.KeyGenerator + alg string + skip bool + }{ + {keyGenerator: keyGenerators["RS256"], alg: "RS256", skip: false}, + {keyGenerator: keyGenerators["ES256"], alg: "ES256", skip: false}, + {keyGenerator: keyGenerators["ES512"], alg: "ES512", skip: false}, + {keyGenerator: keyGenerators["HS256"], alg: "HS256", skip: true}, + {keyGenerator: keyGenerators["HS512"], alg: "HS512", skip: true}, + {keyGenerator: keyGenerators["EdDSA"], alg: "EdDSA", skip: m.Config().HsmEnabled()}, + } { + t.Run("key_generator="+tc.alg, func(t *testing.T) { + if tc.skip { + t.Skipf("Skipping test. Not applicable for alg: %s", tc.alg) + } + if m.Config().HsmEnabled() { + t.Run("TestManagerGenerateAndPersistKeySet", jwk.TestHelperManagerGenerateAndPersistKeySet(m.KeyManager(), tc.alg)) + } else { + kid := uuid.New() + ks, err := tc.keyGenerator.Generate(kid, "sig") + require.NoError(t, err) + t.Run("TestManagerKey", jwk.TestHelperManagerKey(m.KeyManager(), tc.alg, ks, kid)) + t.Run("TestManagerKeySet", jwk.TestHelperManagerKeySet(m.KeyManager(), tc.alg, ks, kid)) + t.Run("TestManagerGenerateAndPersistKeySet", jwk.TestHelperManagerGenerateAndPersistKeySet(m.KeyManager(), tc.alg)) + } }) } + + t.Run("TestManagerGenerateAndPersistKeySetWithUnsupportedKeyGenerator", func(t *testing.T) { + _, err := m.KeyManager().GenerateAndPersistKeySet(context.TODO(), "foo", "bar", "UNKNOWN", "sig") + require.Error(t, err) + assert.IsType(t, errors.WithStack(jwk.ErrUnsupportedKeyAlgorithm), err) + }) }) t.Run("package=grant/trust/manager="+k, func(t *testing.T) { diff --git a/quickstart-hsm.yml b/quickstart-hsm.yml new file mode 100644 index 00000000000..c0314c10cc8 --- /dev/null +++ b/quickstart-hsm.yml @@ -0,0 +1,25 @@ +########################################################################### +####### FOR DEMONSTRATION PURPOSES ONLY ####### +########################################################################### +# # +# If you have not yet read the tutorial, do so now: # +# https://www.ory.sh/docs/hydra/5min-tutorial # +# # +# This set up is only for demonstration purposes. The login # +# endpoint can only be used if you follow the steps in the tutorial. # +# # +########################################################################### + +version: '3.7' + +services: + + hydra: + build: + context: . + dockerfile: .docker/Dockerfile-hsm + environment: + - HSM_ENABLED=true + - HSM_LIBRARY=/usr/lib/softhsm/libsofthsm2.so + - HSM_TOKEN_LABEL=hydra + - HSM_PIN=1234 diff --git a/spec/config.json b/spec/config.json index dbf0733f5be..364fa2c51cd 100644 --- a/spec/config.json +++ b/spec/config.json @@ -451,6 +451,32 @@ "type": "string", "description": "Sets the data source name. This configures the backend where ORY Hydra persists data. If dsn is \"memory\", data will be written to memory and is lost when you restart this instance. ORY Hydra supports popular SQL databases. For more detailed configuration information go to: https://www.ory.sh/docs/hydra/dependencies-environment#sql" }, + "hsm": { + "type": "object", + "additionalProperties": false, + "description": "Configures Hardware Security Module.", + "properties": { + "enabled": { + "type": "boolean" + }, + "library": { + "type": "string", + "description": "Full path (including file extension) of the HSM vendor PKCS#11 library" + }, + "pin": { + "type": "string", + "description": "PIN code for token operations" + }, + "slot": { + "type": "integer", + "description": "Slot ID of the token to use (if label is not specified)" + }, + "token_label": { + "type": "string", + "description": "Label of the token to use (if slot is not specified). If both slot and label are set, token label takes preference over slot. In this case first slot, that contains this label is used." + } + } + }, "webfinger": { "type": "object", "additionalProperties": false,