Skip to content

Commit

Permalink
Add helper for making EC key from components (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhughes authored Aug 9, 2024
1 parent a968dfe commit c870903
Show file tree
Hide file tree
Showing 3 changed files with 453 additions and 2 deletions.
242 changes: 241 additions & 1 deletion include/jwt-cpp/jwt.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ namespace jwt {
get_key_failed,
write_key_failed,
write_cert_failed,
convert_to_pem_failed
convert_to_pem_failed,
unknown_curve,
set_ecdsa_failed
};
/**
* \brief Error category for ECDSA errors
Expand All @@ -194,6 +196,8 @@ namespace jwt {
case ecdsa_error::write_key_failed: return "error writing key data in PEM format";
case ecdsa_error::write_cert_failed: return "error writing cert data in PEM format";
case ecdsa_error::convert_to_pem_failed: return "failed to convert key to pem";
case ecdsa_error::unknown_curve: return "unknown curve";
case ecdsa_error::set_ecdsa_failed: return "set parameters to ECDSA failed";
default: return "unknown ECDSA error";
}
}
Expand Down Expand Up @@ -1085,6 +1089,242 @@ namespace jwt {
error::throw_if_error(ec);
return res;
}

#if defined(JWT_OPENSSL_3_0)

/**
* \brief Convert a curve name to a group name.
*
* \param curve string containing curve name
* \param ec error_code for error_detection
* \return group name
*/
inline std::string curve2group(const std::string curve, std::error_code& ec) {
if (curve == "P-256") {
return "prime256v1";
} else if (curve == "P-384") {
return "secp384r1";
} else if (curve == "P-521") {
return "secp521r1";
} else {
ec = jwt::error::ecdsa_error::unknown_curve;
return {};
}
}

#else

/**
* \brief Convert a curve name to an ID.
*
* \param curve string containing curve name
* \param ec error_code for error_detection
* \return ID
*/
inline int curve2nid(const std::string curve, std::error_code& ec) {
if (curve == "P-256") {
return NID_X9_62_prime256v1;
} else if (curve == "P-384") {
return NID_secp384r1;
} else if (curve == "P-521") {
return NID_secp521r1;
} else {
ec = jwt::error::ecdsa_error::unknown_curve;
return {};
}
}

#endif

/**
* Create public key from curve name and coordinates. This is defined in
* [RFC 7518 Section 6.2](https://www.rfc-editor.org/rfc/rfc7518#section-6.2)
* Using the required "crv" (Curve), "x" (X Coordinate) and "y" (Y Coordinate) Parameters.
*
* \tparam Decode is callable, taking a string_type and returns a string_type.
* It should ensure the padding of the input and then base64url decode and
* return the results.
* \param curve string containing curve name
* \param x string containing base64url encoded x coordinate
* \param y string containing base64url encoded y coordinate
* \param decode The function to decode the RSA parameters
* \param ec error_code for error_detection (gets cleared if no error occur
* \return public key in PEM format
*/
template<typename Decode>
std::string create_public_key_from_ec_components(const std::string& curve, const std::string& x,
const std::string& y, Decode decode, std::error_code& ec) {
ec.clear();
auto decoded_x = decode(x);
auto decoded_y = decode(y);

#if defined(JWT_OPENSSL_3_0)
// OpenSSL deprecated mutable keys and there is a new way for making them
// https://mta.openssl.org/pipermail/openssl-users/2021-July/013994.html
// https://www.openssl.org/docs/man3.1/man3/OSSL_PARAM_BLD_new.html#Example-2
std::unique_ptr<OSSL_PARAM_BLD, decltype(&OSSL_PARAM_BLD_free)> param_bld(OSSL_PARAM_BLD_new(),
OSSL_PARAM_BLD_free);
if (!param_bld) {
ec = error::ecdsa_error::create_context_failed;
return {};
}

std::string group = helper::curve2group(curve, ec);
if (ec) return {};

// https://github.com/openssl/openssl/issues/16270#issuecomment-895734092
std::string pub = std::string("\x04").append(decoded_x).append(decoded_y);

if (OSSL_PARAM_BLD_push_utf8_string(param_bld.get(), "group", group.data(), group.size()) != 1 ||
OSSL_PARAM_BLD_push_octet_string(param_bld.get(), "pub", pub.data(), pub.size()) != 1) {
ec = error::ecdsa_error::set_ecdsa_failed;
return {};
}

std::unique_ptr<OSSL_PARAM, decltype(&OSSL_PARAM_free)> params(OSSL_PARAM_BLD_to_param(param_bld.get()),
OSSL_PARAM_free);
if (!params) {
ec = error::ecdsa_error::set_ecdsa_failed;
return {};
}

std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> ctx(
EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free);
if (!ctx) {
ec = error::ecdsa_error::create_context_failed;
return {};
}

// https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html#EXAMPLES
// Error codes based on https://www.openssl.org/docs/manmaster/man3/EVP_PKEY_fromdata_init.html#RETURN-VALUES
EVP_PKEY* pkey = NULL;
if (EVP_PKEY_fromdata_init(ctx.get()) <= 0 ||
EVP_PKEY_fromdata(ctx.get(), &pkey, EVP_PKEY_KEYPAIR, params.get()) <= 0) {
// It's unclear if this can fail after allocating but free it anyways
// https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html
EVP_PKEY_free(pkey);

ec = error::ecdsa_error::cert_load_failed;
return {};
}

// Transfer ownership so we get ref counter and cleanup
evp_pkey_handle ecdsa(pkey);

#else
int nid = helper::curve2nid(curve, ec);
if (ec) return {};

auto qx = helper::raw2bn(decoded_x, ec);
if (ec) return {};
auto qy = helper::raw2bn(decoded_y, ec);
if (ec) return {};

std::unique_ptr<EC_GROUP, decltype(&EC_GROUP_free)> ecgroup(EC_GROUP_new_by_curve_name(nid), EC_GROUP_free);
if (!ecgroup) {
ec = error::ecdsa_error::set_ecdsa_failed;
return {};
}

EC_GROUP_set_asn1_flag(ecgroup.get(), OPENSSL_EC_NAMED_CURVE);

std::unique_ptr<EC_POINT, decltype(&EC_POINT_free)> ecpoint(EC_POINT_new(ecgroup.get()), EC_POINT_free);
if (!ecpoint ||
EC_POINT_set_affine_coordinates_GFp(ecgroup.get(), ecpoint.get(), qx.get(), qy.get(), nullptr) != 1) {
ec = error::ecdsa_error::set_ecdsa_failed;
return {};
}

std::unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ecdsa(EC_KEY_new(), EC_KEY_free);
if (!ecdsa || EC_KEY_set_group(ecdsa.get(), ecgroup.get()) != 1 ||
EC_KEY_set_public_key(ecdsa.get(), ecpoint.get()) != 1) {
ec = error::ecdsa_error::set_ecdsa_failed;
return {};
}

#endif

auto pub_key_bio = make_mem_buf_bio();
if (!pub_key_bio) {
ec = error::ecdsa_error::create_mem_bio_failed;
return {};
}

auto write_pem_to_bio =
#if defined(JWT_OPENSSL_3_0)
// https://www.openssl.org/docs/man3.1/man3/PEM_write_bio_EC_PUBKEY.html
&PEM_write_bio_PUBKEY;
#else
&PEM_write_bio_EC_PUBKEY;
#endif
if (write_pem_to_bio(pub_key_bio.get(), ecdsa.get()) != 1) {
ec = error::ecdsa_error::load_key_bio_write;
return {};
}

return write_bio_to_string<error::ecdsa_error>(pub_key_bio, ec);
}

/**
* Create public key from curve name and coordinates. This is defined in
* [RFC 7518 Section 6.2](https://www.rfc-editor.org/rfc/rfc7518#section-6.2)
* Using the required "crv" (Curve), "x" (X Coordinate) and "y" (Y Coordinate) Parameters.
*
* \tparam Decode is callable, taking a string_type and returns a string_type.
* It should ensure the padding of the input and then base64url decode and
* return the results.
* \param curve string containing curve name
* \param x string containing base64url encoded x coordinate
* \param y string containing base64url encoded y coordinate
* \param decode The function to decode the RSA parameters
* \return public key in PEM format
*/
template<typename Decode>
std::string create_public_key_from_ec_components(const std::string& curve, const std::string& x,
const std::string& y, Decode decode) {
std::error_code ec;
auto res = create_public_key_from_ec_components(curve, x, y, decode, ec);
error::throw_if_error(ec);
return res;
}

#ifndef JWT_DISABLE_BASE64
/**
* Create public key from curve name and coordinates. This is defined in
* [RFC 7518 Section 6.2](https://www.rfc-editor.org/rfc/rfc7518#section-6.2)
* Using the required "crv" (Curve), "x" (X Coordinate) and "y" (Y Coordinate) Parameters.
*
* \param curve string containing curve name
* \param x string containing base64url encoded x coordinate
* \param y string containing base64url encoded y coordinate
* \param ec error_code for error_detection (gets cleared if no error occur
* \return public key in PEM format
*/
inline std::string create_public_key_from_ec_components(const std::string& curve, const std::string& x,
const std::string& y, std::error_code& ec) {
auto decode = [](const std::string& token) {
return base::decode<alphabet::base64url>(base::pad<alphabet::base64url>(token));
};
return create_public_key_from_ec_components(curve, x, y, std::move(decode), ec);
}
/**
* Create public key from curve name and coordinates. This is defined in
* [RFC 7518 Section 6.2](https://www.rfc-editor.org/rfc/rfc7518#section-6.2)
* Using the required "crv" (Curve), "x" (X Coordinate) and "y" (Y Coordinate) Parameters.
*
* \param curve string containing curve name
* \param x string containing base64url encoded x coordinate
* \param y string containing base64url encoded y coordinate
* \return public key in PEM format
*/
inline std::string create_public_key_from_ec_components(const std::string& curve, const std::string& x,
const std::string& y) {
std::error_code ec;
auto res = create_public_key_from_ec_components(curve, x, y, ec);
error::throw_if_error(ec);
return res;
}
#endif
} // namespace helper

/**
Expand Down
19 changes: 18 additions & 1 deletion tests/HelperTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ JQIDAQAB
ASSERT_EQ(public_key, public_key_expected);
}

TEST(HelperTest, EcFromComponents) {
const std::string public_key_expected =
R"(-----BEGIN PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE0uQ1+1P/wmhOuYvVtTogHOSBLC05IvK7
L6sTPIX8Dl4Bg9nhC3v/FsgifjnXnijUxVJSyWa9SuxwBonUhg6SiCEv+ixb74hj
DesC4D7OwllVcnkDJmOy/NMx4N7yDPJp
-----END PUBLIC KEY-----
)";
const std::string curve = R"(P-384)";
const std::string x = R"(0uQ1-1P_wmhOuYvVtTogHOSBLC05IvK7L6sTPIX8Dl4Bg9nhC3v_FsgifjnXnijU)";
const std::string y = R"(xVJSyWa9SuxwBonUhg6SiCEv-ixb74hjDesC4D7OwllVcnkDJmOy_NMx4N7yDPJp)";

const auto public_key = jwt::helper::create_public_key_from_ec_components(curve, x, y);

ASSERT_EQ(public_key, public_key_expected);
}

TEST(HelperTest, ErrorCodeMessages) {
ASSERT_EQ(std::error_code(jwt::error::rsa_error::ok).message(), "no error");
ASSERT_EQ(std::error_code(static_cast<jwt::error::rsa_error>(-1)).message(), "unknown RSA error");
Expand Down Expand Up @@ -80,7 +97,7 @@ TEST(HelperTest, ErrorCodeMessages) {
ASSERT_EQ(std::error_code(static_cast<jwt::error::rsa_error>(i)).message(),
std::error_code(static_cast<jwt::error::rsa_error>(-1)).message());

for (i = 10; i < 22; i++) {
for (i = 10; i < 24; i++) {
ASSERT_NE(std::error_code(static_cast<jwt::error::ecdsa_error>(i)).message(),
std::error_code(static_cast<jwt::error::ecdsa_error>(-1)).message());
}
Expand Down
Loading

0 comments on commit c870903

Please sign in to comment.