From 08e056fc7a440f266be76f2259642cedea5b1c50 Mon Sep 17 00:00:00 2001 From: Max Tropets Date: Tue, 3 Dec 2024 19:48:24 +0000 Subject: [PATCH] WIP raw key JWT auth . --- doc/schemas/gov_openapi.json | 56 +----- include/ccf/crypto/jwk.h | 18 +- include/ccf/crypto/rsa_public_key.h | 7 + .../ccf/endpoints/authentication/jwt_auth.h | 4 +- include/ccf/service/tables/jwt.h | 9 +- samples/constitutions/default/actions.js | 25 ++- src/crypto/openssl/rsa_public_key.cpp | 16 ++ src/crypto/openssl/rsa_public_key.h | 7 + src/endpoints/authentication/jwt_auth.cpp | 52 ++++-- src/node/gov/handlers/service_state.h | 12 +- src/node/rpc/jwt_management.h | 169 +++++++++++------- src/node/rpc/member_frontend.h | 32 ---- tests/infra/jwt_issuer.py | 53 +++--- .../custom_authorization.py | 66 +++++++ tests/jwt_test.py | 49 ++--- 15 files changed, 330 insertions(+), 245 deletions(-) diff --git a/doc/schemas/gov_openapi.json b/doc/schemas/gov_openapi.json index 90ab5ed30d61..1c3102936e97 100644 --- a/doc/schemas/gov_openapi.json +++ b/doc/schemas/gov_openapi.json @@ -291,27 +291,6 @@ }, "type": "object" }, - "KeyIdInfo": { - "properties": { - "cert": { - "$ref": "#/components/schemas/Pem" - }, - "issuer": { - "$ref": "#/components/schemas/string" - } - }, - "required": [ - "issuer", - "cert" - ], - "type": "object" - }, - "KeyIdInfo_array": { - "items": { - "$ref": "#/components/schemas/KeyIdInfo" - }, - "type": "array" - }, "MDType": { "enum": [ "NONE", @@ -808,10 +787,12 @@ }, "issuer": { "$ref": "#/components/schemas/string" + }, + "public_key": { + "$ref": "#/components/schemas/base64string" } }, "required": [ - "cert", "issuer" ], "type": "object" @@ -1222,12 +1203,6 @@ }, "type": "object" }, - "string_to_KeyIdInfo_array": { - "additionalProperties": { - "$ref": "#/components/schemas/KeyIdInfo_array" - }, - "type": "object" - }, "string_to_OpenIDJWKMetadata_array": { "additionalProperties": { "$ref": "#/components/schemas/OpenIDJWKMetadata_array" @@ -1473,31 +1448,6 @@ } ] }, - "/gov/jwt_keys/all": { - "get": { - "deprecated": true, - "description": "This endpoint is deprecated from 5.0.0. It is replaced by POST /gov/service/jwk", - "operationId": "GetGovJwtKeysAll", - "responses": { - "200": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/string_to_KeyIdInfo_array" - } - } - }, - "description": "Default response description" - }, - "default": { - "$ref": "#/components/responses/default" - } - }, - "x-ccf-forwarding": { - "$ref": "#/components/x-ccf-forwarding/always" - } - } - }, "/gov/kv/constitution": { "get": { "deprecated": true, diff --git a/include/ccf/crypto/jwk.h b/include/ccf/crypto/jwk.h index 1b4886cb1a22..fa3200f39948 100644 --- a/include/ccf/crypto/jwk.h +++ b/include/ccf/crypto/jwk.h @@ -27,13 +27,27 @@ namespace ccf::crypto JsonWebKeyType kty; std::optional kid = std::nullopt; std::optional> x5c = std::nullopt; - std::optional issuer = std::nullopt; bool operator==(const JsonWebKey&) const = default; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKey); DECLARE_JSON_REQUIRED_FIELDS(JsonWebKey, kty); - DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c, issuer); + DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKey, kid, x5c); + + struct JsonWebKeyExtended + { + JsonWebKeyType kty; + std::optional kid = std::nullopt; + std::optional> x5c = std::nullopt; + std::optional n = std::nullopt; + std::optional e = std::nullopt; + std::optional issuer = std::nullopt; + + bool operator==(const JsonWebKeyExtended&) const = default; + }; + DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(JsonWebKeyExtended); + DECLARE_JSON_REQUIRED_FIELDS(JsonWebKeyExtended, kty); + DECLARE_JSON_OPTIONAL_FIELDS(JsonWebKeyExtended, kid, x5c, n, e, issuer); enum class JsonWebKeyECCurve { diff --git a/include/ccf/crypto/rsa_public_key.h b/include/ccf/crypto/rsa_public_key.h index cd62eba0e7f4..1fcd81dc6d43 100644 --- a/include/ccf/crypto/rsa_public_key.h +++ b/include/ccf/crypto/rsa_public_key.h @@ -84,6 +84,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_legth = 0) = 0; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) = 0; + struct Components { std::vector n; diff --git a/include/ccf/endpoints/authentication/jwt_auth.h b/include/ccf/endpoints/authentication/jwt_auth.h index 3a44ee55a73a..70d4f3f7c813 100644 --- a/include/ccf/endpoints/authentication/jwt_auth.h +++ b/include/ccf/endpoints/authentication/jwt_auth.h @@ -17,7 +17,7 @@ namespace ccf nlohmann::json payload; }; - struct VerifiersCache; + struct PublicKeysCache; bool validate_issuer( const std::string& iss, @@ -28,7 +28,7 @@ namespace ccf { protected: static const OpenAPISecuritySchema security_schema; - std::unique_ptr verifiers; + std::unique_ptr keys_cache; public: static constexpr auto SECURITY_SCHEME_NAME = "jwt"; diff --git a/include/ccf/service/tables/jwt.h b/include/ccf/service/tables/jwt.h index 23ebe5268499..87c4c9eb49b2 100644 --- a/include/ccf/service/tables/jwt.h +++ b/include/ccf/service/tables/jwt.h @@ -37,16 +37,17 @@ namespace ccf using JwtIssuer = std::string; using JwtKeyId = std::string; using Cert = std::vector; + using PublicKey = std::vector; struct OpenIDJWKMetadata { - Cert cert; + std::optional public_key; JwtIssuer issuer; std::optional constraint; }; DECLARE_JSON_TYPE_WITH_OPTIONAL_FIELDS(OpenIDJWKMetadata); - DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, cert, issuer); - DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadata, constraint); + DECLARE_JSON_REQUIRED_FIELDS(OpenIDJWKMetadata, issuer); + DECLARE_JSON_OPTIONAL_FIELDS(OpenIDJWKMetadata, public_key, constraint); using JwtIssuers = ServiceMap; using JwtPublicSigningKeys = @@ -75,7 +76,7 @@ namespace ccf struct JsonWebKeySet { - std::vector keys; + std::vector keys; bool operator!=(const JsonWebKeySet& rhs) const { diff --git a/samples/constitutions/default/actions.js b/samples/constitutions/default/actions.js index 654ecc065326..1d1d405c6782 100644 --- a/samples/constitutions/default/actions.js +++ b/samples/constitutions/default/actions.js @@ -130,15 +130,22 @@ function checkJwks(value, field) { for (const [i, jwk] of value.keys.entries()) { checkType(jwk.kid, "string", `${field}.keys[${i}].kid`); checkType(jwk.kty, "string", `${field}.keys[${i}].kty`); - checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); - checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); - for (const [j, b64der] of jwk.x5c.entries()) { - checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); - const pem = - "-----BEGIN CERTIFICATE-----\n" + - b64der + - "\n-----END CERTIFICATE-----"; - checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + if (jwk.x5c) { + checkType(jwk.x5c, "array", `${field}.keys[${i}].x5c`); + checkLength(jwk.x5c, 1, null, `${field}.keys[${i}].x5c`); + for (const [j, b64der] of jwk.x5c.entries()) { + checkType(b64der, "string", `${field}.keys[${i}].x5c[${j}]`); + const pem = + "-----BEGIN CERTIFICATE-----\n" + + b64der + + "\n-----END CERTIFICATE-----"; + checkX509CertBundle(pem, `${field}.keys[${i}].x5c[${j}]`); + } + } else if (jwk.n && jwk.e) { + checkType(jwk.n, "string", `${field}.keys[${i}].n`); + checkType(jwk.e, "string", `${field}.keys[${i}].e`); + } else { + throw new Error("JWK must contain either x5c or n and e"); } } } diff --git a/src/crypto/openssl/rsa_public_key.cpp b/src/crypto/openssl/rsa_public_key.cpp index b8fb2f61be58..a8bbec19f319 100644 --- a/src/crypto/openssl/rsa_public_key.cpp +++ b/src/crypto/openssl/rsa_public_key.cpp @@ -208,6 +208,22 @@ namespace ccf::crypto pctx, signature, signature_size, hash.data(), hash.size()) == 1; } + bool RSAPublicKey_OpenSSL::verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type) + { + auto hash = OpenSSLHashProvider().Hash(contents, contents_size, md_type); + Unique_EVP_PKEY_CTX pctx(key); + CHECK1(EVP_PKEY_verify_init(pctx)); + CHECK1(EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PADDING)); + CHECK1(EVP_PKEY_CTX_set_signature_md(pctx, get_md_type(md_type))); + return EVP_PKEY_verify( + pctx, signature, signature_size, hash.data(), hash.size()) == 1; + } + std::vector RSAPublicKey_OpenSSL::bn_bytes(const BIGNUM* bn) { std::vector r(BN_num_bytes(bn)); diff --git a/src/crypto/openssl/rsa_public_key.h b/src/crypto/openssl/rsa_public_key.h index 061ba053ad80..abe43fcf758a 100644 --- a/src/crypto/openssl/rsa_public_key.h +++ b/src/crypto/openssl/rsa_public_key.h @@ -55,6 +55,13 @@ namespace ccf::crypto MDType md_type = MDType::NONE, size_t salt_length = 0) override; + virtual bool verify_pkcs1( + const uint8_t* contents, + size_t contents_size, + const uint8_t* signature, + size_t signature_size, + MDType md_type = MDType::NONE) override; + virtual Components components() const override; static std::vector bn_bytes(const BIGNUM* bn); diff --git a/src/endpoints/authentication/jwt_auth.cpp b/src/endpoints/authentication/jwt_auth.cpp index 05ceb862ff2d..0b58a91188ef 100644 --- a/src/endpoints/authentication/jwt_auth.cpp +++ b/src/endpoints/authentication/jwt_auth.cpp @@ -3,6 +3,7 @@ #include "ccf/endpoints/authentication/jwt_auth.h" +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/ds/nonstd.h" #include "ccf/pal/locking.h" #include "ccf/rpc_context.h" @@ -82,26 +83,24 @@ namespace ccf return tenant_id && tid && *tid == *tenant_id; } - struct VerifiersCache + struct PublicKeysCache { - static constexpr size_t DEFAULT_MAX_VERIFIERS = 10; + static constexpr size_t DEFAULT_MAX_KEYS = 10; using DER = std::vector; - ccf::pal::Mutex verifiers_lock; - LRU verifiers; + ccf::pal::Mutex keys_lock; + LRU keys; - VerifiersCache(size_t max_verifiers = DEFAULT_MAX_VERIFIERS) : - verifiers(max_verifiers) - {} + PublicKeysCache(size_t max_keys = DEFAULT_MAX_KEYS) : keys(max_keys) {} - ccf::crypto::VerifierPtr get_verifier(const DER& der) + ccf::crypto::RSAPublicKeyPtr get_key(const DER& der) { - std::lock_guard guard(verifiers_lock); + std::lock_guard guard(keys_lock); - auto it = verifiers.find(der); - if (it == verifiers.end()) + auto it = keys.find(der); + if (it == keys.end()) { - it = verifiers.insert(der, ccf::crypto::make_unique_verifier(der)); + it = keys.insert(der, ccf::crypto::make_rsa_public_key(der)); } return it->second; @@ -109,7 +108,7 @@ namespace ccf }; JwtAuthnPolicy::JwtAuthnPolicy() : - verifiers(std::make_unique()) + keys_cache(std::make_unique()) {} JwtAuthnPolicy::~JwtAuthnPolicy() = default; @@ -141,11 +140,14 @@ namespace ccf auto fallback_issuers = tx.ro( ccf::Tables::Legacy::JWT_PUBLIC_SIGNING_KEY_ISSUER); - auto fallback_key = fallback_keys->get(key_id); - if (fallback_key) + auto fallback_cert = fallback_keys->get(key_id); + if (fallback_cert) { + // Legacy keys are stored as certs, new approach is raw keys, so + // conversion is needed to implicitly work futher down the code. + auto verifier = ccf::crypto::make_unique_verifier(*fallback_cert); token_keys = std::vector{OpenIDJWKMetadata{ - .cert = *fallback_key, + .public_key = verifier->public_key_der(), .issuer = *fallback_issuers->get(key_id), .constraint = std::nullopt}}; } @@ -160,8 +162,22 @@ namespace ccf for (const auto& metadata : *token_keys) { - auto verifier = verifiers->get_verifier(metadata.cert); - if (!::http::JwtVerifier::validate_token_signature(token, verifier)) + if (!metadata.public_key.has_value()) + { + error_reason = + fmt::format("Missing public key for a given kid: {}", key_id); + continue; + } + + const auto pubkey = keys_cache->get_key(metadata.public_key.value()); + // Obsolote PKCS1 padding is chosen for JWT, as explained in details here: + // https://github.com/microsoft/CCF/issues/6601#issuecomment-2512059875. + if (!pubkey->verify_pkcs1( + (uint8_t*)token.signed_content.data(), + token.signed_content.size(), + token.signature.data(), + token.signature.size(), + ccf::crypto::MDType::SHA256)) { error_reason = "Signature verification failed"; continue; diff --git a/src/node/gov/handlers/service_state.h b/src/node/gov/handlers/service_state.h index c941e14f960a..9fdc53510020 100644 --- a/src/node/gov/handlers/service_state.h +++ b/src/node/gov/handlers/service_state.h @@ -578,11 +578,13 @@ namespace ccf::gov::endpoints { auto info = nlohmann::json::object(); - // cert is stored as DER - convert to PEM for API - const auto cert_pem = - ccf::crypto::cert_der_to_pem(metadata.cert); - info["certificate"] = cert_pem.str(); - + if (metadata.public_key.has_value()) + { + info["publicKey"] = ccf::crypto::make_rsa_public_key( + metadata.public_key.value()) + ->public_key_pem() + .str(); + } info["issuer"] = metadata.issuer; info["constraint"] = metadata.constraint; diff --git a/src/node/rpc/jwt_management.h b/src/node/rpc/jwt_management.h index af7c011ac8ef..ae71cb47bf16 100644 --- a/src/node/rpc/jwt_management.h +++ b/src/node/rpc/jwt_management.h @@ -2,6 +2,7 @@ // Licensed under the Apache 2.0 License. #pragma once +#include "ccf/crypto/rsa_key_pair.h" #include "ccf/crypto/verifier.h" #include "ccf/ds/hex.h" #include "ccf/service/tables/jwt.h" @@ -12,6 +13,64 @@ #include #include +namespace +{ + std::vector try_parse_jwk(const ccf::crypto::JsonWebKeyExtended& jwk) + { + const auto& kid = jwk.kid.value(); + if ( + jwk.e.has_value() && !jwk.e->empty() && jwk.n.has_value() && + !jwk.n->empty()) + { + std::vector der; + ccf::crypto::JsonWebKeyRSAPublic data; + data.kty = ccf::crypto::JsonWebKeyType::RSA; + data.kid = jwk.kid; + data.n = jwk.n.value(); + data.e = jwk.e.value(); + try + { + const auto pubkey = ccf::crypto::make_rsa_public_key(data); + return pubkey->public_key_der(); + } + catch (const std::invalid_argument& exc) + { + throw std::logic_error( + fmt::format("Failed to construct RSA public key: {}", exc.what())); + } + } + else if (jwk.x5c.has_value() && !jwk.x5c->empty()) + { + auto& der_base64 = jwk.x5c.value()[0]; + ccf::Cert der; + try + { + der = ccf::crypto::raw_from_b64(der_base64); + } + catch (const std::invalid_argument& e) + { + throw std::logic_error( + fmt::format("Could not parse x5c of key id {}: {}", kid, e.what())); + } + try + { + auto verifier = ccf::crypto::make_unique_verifier(der); + return verifier->public_key_der(); + } + catch (std::invalid_argument& exc) + { + throw std::logic_error(fmt::format( + "JWKS kid {} has an invalid X.509 certificate: {}", kid, exc.what())); + } + } + else + { + throw std::logic_error( + fmt::format("JWKS kid {} has neither x5c or RSA public key", kid)); + } + } +} + namespace ccf { static void legacy_remove_jwt_public_signing_keys( @@ -37,8 +96,8 @@ namespace ccf const std::string& issuer, const std::string& constraint) { // Only accept key constraints for the same (sub)domain. This is to avoid - // setting keys from issuer A which will be used to validate iss claims for - // issuer B, so this doesn't make sense (at least for now). + // setting keys from issuer A which will be used to validate iss claims + // for issuer B, so this doesn't make sense (at least for now). const auto issuer_domain = ::http::parse_url_full(issuer).host; const auto constraint_domain = ::http::parse_url_full(constraint).host; @@ -48,13 +107,13 @@ namespace ccf return false; } - // Either constraint's domain == issuer's domain or it is a subdomain, e.g.: - // limited.facebook.com + // Either constraint's domain == issuer's domain or it is a subdomain, + // e.g.: limited.facebook.com // .facebook.com // // It may make sense to support vice-versa too, but we haven't found any - // instances of that so far, so leaveing it only-way only for facebook-like - // cases. + // instances of that so far, so leaveing it only-way only for + // facebook-like cases. if (issuer_domain != constraint_domain) { const auto pattern = "." + constraint_domain; @@ -68,8 +127,8 @@ namespace ccf ccf::kv::Tx& tx, std::string issuer) { // Unlike resetting JWT keys for a particular issuer, removing keys can be - // safely done on both table revisions, as soon as the application shouldn't - // use them anyway after being ask about that explicitly. + // safely done on both table revisions, as soon as the application + // shouldn't use them anyway after being ask about that explicitly. legacy_remove_jwt_public_signing_keys(tx, issuer); auto keys = @@ -113,74 +172,45 @@ namespace ccf LOG_FAIL_FMT("{}: JWKS has no keys", log_prefix); return false; } - std::map> new_keys; + std::map new_keys; std::map issuer_constraints; - for (auto& jwk : jwks.keys) - { - if (!jwk.kid.has_value()) - { - LOG_FAIL_FMT("No kid for JWT signing key"); - return false; - } - - if (!jwk.x5c.has_value() && jwk.x5c->empty()) - { - LOG_FAIL_FMT("{}: JWKS is invalid (empty x5c)", log_prefix); - return false; - } - auto& der_base64 = jwk.x5c.value()[0]; - ccf::Cert der; - auto const& kid = jwk.kid.value(); - try - { - der = ccf::crypto::raw_from_b64(der_base64); - } - catch (const std::invalid_argument& e) - { - LOG_FAIL_FMT( - "{}: Could not parse x5c of key id {}: {}", - log_prefix, - kid, - e.what()); - return false; - } - - try - { - ccf::crypto::make_unique_verifier( - (std::vector)der); // throws on error - } - catch (std::invalid_argument& exc) + try + { + for (auto& jwk : jwks.keys) { - LOG_FAIL_FMT( - "{}: JWKS kid {} has an invalid X.509 certificate: {}", - log_prefix, - kid, - exc.what()); - return false; - } + if (!jwk.kid.has_value()) + { + throw(std::logic_error("Missing kid for JWT signing key")); + } - LOG_INFO_FMT("{}: Storing JWT signing key with kid {}", log_prefix, kid); - new_keys.emplace(kid, der); + const auto& kid = jwk.kid.value(); + auto key_der = try_parse_jwk(jwk); - if (jwk.issuer) - { - if (!check_issuer_constraint(issuer, *jwk.issuer)) + if (jwk.issuer) { - LOG_FAIL_FMT( - "{}: JWKS kid {} with issuer constraint {} fails validation " - "against issuer {}", - log_prefix, - kid, - *jwk.issuer, - issuer); - return false; + if (!check_issuer_constraint(issuer, *jwk.issuer)) + { + throw std::logic_error(fmt::format( + "JWKS kid {} with issuer constraint {} fails validation " + "against " + "issuer {}", + kid, + *jwk.issuer, + issuer)); + } + + issuer_constraints.emplace(kid, *jwk.issuer); } - issuer_constraints.emplace(kid, *jwk.issuer); + new_keys.emplace(kid, key_der); } } + catch (const std::exception& exc) + { + LOG_FAIL_FMT("{}: {}", log_prefix, exc.what()); + return false; + } if (new_keys.empty()) { @@ -203,7 +233,10 @@ namespace ccf for (auto& [kid, der] : new_keys) { - OpenIDJWKMetadata value{der, issuer, std::nullopt}; + OpenIDJWKMetadata value{ + .public_key = der, .issuer = issuer, .constraint = std::nullopt}; + value.public_key = der; + const auto it = issuer_constraints.find(kid); if (it != issuer_constraints.end()) { @@ -218,7 +251,7 @@ namespace ccf keys_for_kid->begin(), keys_for_kid->end(), [&value](const auto& metadata) { - return metadata.cert == value.cert && + return metadata.public_key == value.public_key && metadata.issuer == value.issuer && metadata.constraint == value.constraint; }) != keys_for_kid->end()) diff --git a/src/node/rpc/member_frontend.h b/src/node/rpc/member_frontend.h index 438b8709b45f..b8a9620826d2 100644 --- a/src/node/rpc/member_frontend.h +++ b/src/node/rpc/member_frontend.h @@ -67,14 +67,6 @@ namespace ccf DECLARE_JSON_TYPE(JsBundle) DECLARE_JSON_REQUIRED_FIELDS(JsBundle, metadata, modules) - struct KeyIdInfo - { - JwtIssuer issuer; - ccf::crypto::Pem cert; - }; - DECLARE_JSON_TYPE(KeyIdInfo) - DECLARE_JSON_REQUIRED_FIELDS(KeyIdInfo, issuer, cert) - struct FullMemberDetails : public ccf::MemberDetails { ccf::crypto::Pem cert; @@ -1098,30 +1090,6 @@ namespace ccf "5.0.0", "POST /gov/recovery/members/{memberId}:recover") .install(); - using JWTKeyMap = std::map>; - - auto get_jwt_keys = [this](auto& ctx, nlohmann::json&& body) { - auto keys = ctx.tx.ro(network.jwt_public_signing_keys_metadata); - JWTKeyMap kmap; - keys->foreach([&kmap](const auto& k, const auto& v) { - std::vector info; - for (const auto& metadata : v) - { - info.push_back(KeyIdInfo{ - metadata.issuer, ccf::crypto::cert_der_to_pem(metadata.cert)}); - } - kmap.emplace(k, std::move(info)); - return true; - }); - - return make_success(kmap); - }; - make_endpoint( - "/jwt_keys/all", HTTP_GET, json_adapter(get_jwt_keys), no_auth_required) - .set_auto_schema() - .set_openapi_deprecated_replaced("5.0.0", "POST /gov/service/jwk") - .install(); - auto post_proposals_js = [this](ccf::endpoints::EndpointContext& ctx) { std::optional cose_auth_id = std::nullopt; diff --git a/tests/infra/jwt_issuer.py b/tests/infra/jwt_issuer.py index 1882e57b02f5..f3c10b7114db 100644 --- a/tests/infra/jwt_issuer.py +++ b/tests/infra/jwt_issuer.py @@ -107,6 +107,22 @@ def __exit__(self, exc_type, exc_value, traceback): self.stop() +def get_jwt_issuers(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["issuers"] + + +def get_jwt_keys(args, node): + with node.api_versioned_client(api_version=args.gov_api_version) as c: + r = c.get("/gov/service/jwk") + assert r.status_code == HTTPStatus.OK, r + body = r.body.json() + return body["keys"] + + class JwtIssuer: TEST_JWT_ISSUER_NAME = "https://example.issuer" TEST_CA_BUNDLE_NAME = "test_ca_bundle_name" @@ -237,31 +253,26 @@ def wait_for_refresh(self, network, args, kid=None): LOG.warning(body) keys = body["keys"] if kid_ in keys: - stored_cert = keys[kid_][0]["certificate"] - if self.cert_pem == stored_cert: + stored_key = keys[kid_][0]["publicKey"] + if self.key_pub_pem == stored_key: flush_info(logs) return time.sleep(0.1) else: - with primary.client( - network.consortium.get_any_active_member().local_id - ) as c: - while time.time() < end_time: - logs = [] - r = c.get("/gov/jwt_keys/all", log_capture=logs) - assert r.status_code == 200, r - keys = r.body.json() - if kid_ in keys: - kid_vals = keys[kid_] - if primary.version_after("ccf-5.0.0-dev17"): - assert len(kid_vals) == 1 - stored_cert = kid_vals[0]["cert"] - else: - stored_cert = kid_vals["cert"] - if self.cert_pem == stored_cert: - flush_info(logs) - return - time.sleep(0.1) + while time.time() < end_time: + logs = [] + keys = get_jwt_keys(args, primary) + if kid_ in keys: + kid_vals = keys[kid_] + if primary.version_after("ccf-5.0.0-dev17"): + assert len(kid_vals) == 1 + stored_cert = kid_vals[0]["cert"] + else: + stored_cert = kid_vals["cert"] + if self.cert_pem == stored_cert: + flush_info(logs) + return + time.sleep(0.1) flush_info(logs) raise TimeoutError( f"JWT public signing keys were not refreshed after {timeout}s" diff --git a/tests/js-custom-authorization/custom_authorization.py b/tests/js-custom-authorization/custom_authorization.py index 2dbd949e4ec0..7f7aa068687c 100644 --- a/tests/js-custom-authorization/custom_authorization.py +++ b/tests/js-custom-authorization/custom_authorization.py @@ -20,6 +20,8 @@ from http import HTTPStatus import subprocess from contextlib import contextmanager +from cryptography.x509 import load_pem_x509_certificate +from cryptography.hazmat.backends import default_backend from loguru import logger as LOG @@ -102,6 +104,35 @@ def set_issuer_with_a_key(primary, network, issuer, kid, constraint): network.consortium.set_jwt_issuer(primary, metadata_fp.name) +def to_b64(number: int): + as_bytes = number.to_bytes((number.bit_length() + 7) // 8, "big") + return base64.b64encode(as_bytes).decode("ascii") + + +def set_issuer_with_a_raw_key(primary, network, issuer, kid, constraint): + with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as metadata_fp: + cert = load_pem_x509_certificate(issuer.cert_pem.encode(), default_backend()) + pubkey = cert.public_key() + data = { + "issuer": issuer.issuer_url, + "auto_refresh": False, + "jwks": { + "keys": [ + { + "kty": "RSA", + "kid": kid, + "n": to_b64(pubkey.public_numbers().n), + "e": to_b64(pubkey.public_numbers().e), + "issuer": constraint, + } + ] + }, + } + json.dump(data, metadata_fp) + metadata_fp.flush() + network.consortium.set_jwt_issuer(primary, metadata_fp.name) + + def parse_error_message(r): return r.body.json()["error"]["details"][0]["message"] @@ -394,6 +425,40 @@ def test_jwt_auth(network, args): return network +@reqs.description("JWT authentication as by OpenID spec with raw public key") +def test_jwt_auth_raw_key(network, args): + primary, _ = network.find_nodes() + + issuer = infra.jwt_issuer.JwtIssuer("https://example.issuer") + + jwt_kid = "my_key_id" + + LOG.info("Add JWT issuer with initial keys") + + set_issuer_with_a_raw_key(primary, network, issuer, jwt_kid, issuer.name) + + LOG.info("Calling jwt endpoint after storing keys") + with primary.client("user0") as c: + r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header("garbage")) + assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code + assert "Malformed JWT" in parse_error_message(r), r + + jwt_mismatching_key_priv_pem, _ = infra.crypto.generate_rsa_keypair(2048) + jwt = infra.crypto.create_jwt({}, jwt_mismatching_key_priv_pem, jwt_kid) + r = c.get("/app/jwt", headers=infra.jwt_issuer.make_bearer_header(jwt)) + assert r.status_code == HTTPStatus.UNAUTHORIZED, r.status_code + assert "JWT payload is missing required field" in parse_error_message(r), r + + r = c.get( + "/app/jwt", + headers=infra.jwt_issuer.make_bearer_header(issuer.issue_jwt(jwt_kid)), + ) + assert r.status_code == HTTPStatus.OK, r.status_code + + network.consortium.remove_jwt_issuer(primary, issuer.name) + return network + + @reqs.description("JWT authentication as by MSFT Entra (single tenant)") def test_jwt_auth_msft_single_tenant(network, args): """For a specific tenant, only tokens with this issuer+tenant can auth.""" @@ -708,6 +773,7 @@ def run_authn(args): network.start_and_open(args) network = test_cert_auth(network, args) network = test_jwt_auth(network, args) + network = test_jwt_auth_raw_key(network, args) network = test_jwt_auth_msft_single_tenant(network, args) network = test_jwt_auth_msft_multitenancy(network, args) network = test_jwt_auth_msft_same_kids_different_issuers(network, args) diff --git a/tests/jwt_test.py b/tests/jwt_test.py index ef0e861fd3f6..afd1cc2d85c3 100644 --- a/tests/jwt_test.py +++ b/tests/jwt_test.py @@ -12,33 +12,16 @@ import infra.e2e_args import infra.proposal import suite.test_requirements as reqs -import infra.jwt_issuer +from infra.jwt_issuer import get_jwt_issuers, get_jwt_keys from infra.runner import ConcurrentRunner import ca_certs import ccf.ledger from ccf.tx_id import TxID import infra.clients -import http from loguru import logger as LOG -def get_jwt_issuers(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["issuers"] - - -def get_jwt_keys(args, node): - with node.api_versioned_client(api_version=args.gov_api_version) as c: - r = c.get("/gov/service/jwk") - assert r.status_code == http.HTTPStatus.OK, r - body = r.body.json() - return body["keys"] - - def set_issuer_with_keys(network, primary, issuer, kids): with tempfile.NamedTemporaryFile(prefix="ccf", mode="w+") as metadata_fp: json.dump({"issuer": issuer.name}, metadata_fp) @@ -213,7 +196,7 @@ def test_jwt_endpoint(network, args): assert kid in service_keys, service_keys assert service_keys[kid][0]["issuer"] == issuer.name assert service_keys[kid][0]["constraint"] == issuer.name - assert service_keys[kid][0]["certificate"] == issuer.cert_pem + assert service_keys[kid][0]["publicKey"] == issuer.key_pub_pem @reqs.description("JWT without key policy") @@ -266,9 +249,9 @@ def test_jwt_without_key_policy(network, args): ) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" LOG.info("Remove JWT issuer") network.consortium.remove_jwt_issuer(primary, issuer.name) @@ -285,9 +268,9 @@ def test_jwt_without_key_policy(network, args): network.consortium.set_jwt_issuer(primary, metadata_fp.name) keys = get_jwt_keys(args, primary) - stored_cert = keys[kid][0]["certificate"] + stored_key = keys[kid][0]["publicKey"] - assert stored_cert == issuer.cert_pem, "input cert is not equal to stored cert" + assert stored_key == issuer.key_pub_pem, "input key is not equal to stored key" return network @@ -320,18 +303,18 @@ def make_attested_cert(network, args): return pem -def check_kv_jwt_key_matches(args, network, kid, cert_pem): +def check_kv_jwt_key_matches(args, network, kid, key_pem): primary, _ = network.find_nodes() latest_jwt_signing_keys = get_jwt_keys(args, primary) - if cert_pem is None: + if key_pem is None: assert kid not in latest_jwt_signing_keys else: # Necessary to get an AssertionError if the key is not found yet, # when used from with_timeout() assert kid in latest_jwt_signing_keys - stored_cert = latest_jwt_signing_keys[kid][0]["certificate"] - assert stored_cert == cert_pem, "input cert is not equal to stored cert" + stored_key = latest_jwt_signing_keys[kid][0]["publicKey"] + assert stored_key == key_pem, "input cert is not equal to stored cert" def check_kv_jwt_keys_not_empty(args, network, issuer): @@ -405,7 +388,9 @@ def test_jwt_key_auto_refresh(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -438,7 +423,7 @@ def check_has_failures(): with_timeout( lambda: check_kv_jwt_key_matches(args, network, kid, None), timeout=5 ) - check_kv_jwt_key_matches(args, network, kid2, issuer.cert_pem) + check_kv_jwt_key_matches(args, network, kid2, issuer.key_pub_pem) return network @@ -482,7 +467,9 @@ def test_jwt_key_auto_refresh_entries(network, args): LOG.info("Check that keys got refreshed") # Note: refresh interval is set to 1s, see network args below. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches( + args, network, kid, issuer.key_pub_pem + ), timeout=5, ) @@ -567,7 +554,7 @@ def test_jwt_key_initial_refresh(network, args): # Auto-refresh interval has been set to a large value so that it doesn't happen within the timeout. # This is testing the one-off refresh after adding a new issuer. with_timeout( - lambda: check_kv_jwt_key_matches(args, network, kid, issuer.cert_pem), + lambda: check_kv_jwt_key_matches(args, network, kid, issuer.key_pub_pem), timeout=5, )