Skip to content

Commit

Permalink
Apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
FAlbertDev committed Jul 27, 2023
1 parent 3f7c9c4 commit 908edb1
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 115 deletions.
21 changes: 11 additions & 10 deletions src/lib/pubkey/pubkey.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,14 +552,17 @@ class KEM_Encapsulation final {
*/
const std::vector<uint8_t>& encapsulated_shared_key() const { return m_encapsulated_shared_key; }

std::vector<uint8_t>& encapsulated_shared_key() { return m_encapsulated_shared_key; }

/**
* @returns the plaintext shared secret
*/
const secure_vector<uint8_t>& shared_key() const { return m_shared_key; }

secure_vector<uint8_t>& shared_key() { return m_shared_key; }
/**
* @returns the pair (encapsulated key, key) extracted from @p kem
*/
static std::pair<std::vector<uint8_t>, secure_vector<uint8_t>> destructure(KEM_Encapsulation&& kem) {
return std::make_pair(std::exchange(kem.m_encapsulated_shared_key, {}), std::exchange(kem.m_shared_key, {}));
}

private:
friend class PK_KEM_Encryptor;
Expand Down Expand Up @@ -644,13 +647,11 @@ class BOTAN_PUBLIC_API(2, 0) PK_KEM_Encryptor final {
KEM_Encapsulation encrypt(RandomNumberGenerator& rng,
size_t desired_shared_key_len = 32,
std::span<const uint8_t> salt = {}) {
KEM_Encapsulation result(encapsulated_key_length(), shared_key_length(desired_shared_key_len));
encrypt(std::span{result.encapsulated_shared_key()},
std::span{result.shared_key()},
rng,
desired_shared_key_len,
salt);
return result;
std::vector<uint8_t> encapsulated_shared_key(encapsulated_key_length());
secure_vector<uint8_t> shared_key(shared_key_length(desired_shared_key_len));

encrypt(std::span{encapsulated_shared_key}, std::span{shared_key}, rng, desired_shared_key_len, salt);
return KEM_Encapsulation(std::move(encapsulated_shared_key), std::move(shared_key));
}

/**
Expand Down
58 changes: 23 additions & 35 deletions src/lib/tls/msg_client_hello.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,27 @@ std::vector<uint8_t> Hello_Request::serialize() const {
return std::vector<uint8_t>();
}

std::unique_ptr<Supported_Groups> Client_Hello_12::make_tls12_supported_groups(const Policy& policy) const {
// RFC 7919 3.
// A client that offers a group MUST be able and willing to perform a DH
// key exchange using that group.
//
// We don't support hybrid key exchange in TLS 1.2
const std::vector<Group_Params> kex_groups = policy.key_exchange_groups();
std::vector<Group_Params> compatible_kex_groups;
std::copy_if(kex_groups.begin(), kex_groups.end(), std::back_inserter(compatible_kex_groups), [](const auto group) {
return is_ecdh(group) || is_dh(group) || is_x25519(group);
});

auto supported_groups = std::make_unique<Supported_Groups>(std::move(compatible_kex_groups));

if(!supported_groups->ec_groups().empty()) {
m_data->extensions().add(new Supported_Point_Formats(policy.use_ecc_point_compression()));
}

return supported_groups;
}

Client_Hello_12::Client_Hello_12(std::unique_ptr<Client_Hello_Internal> data) : Client_Hello(std::move(data)) {
if(offered_suite(static_cast<uint16_t>(TLS_EMPTY_RENEGOTIATION_INFO_SCSV))) {
if(Renegotiation_Extension* reneg = m_data->extensions().get<Renegotiation_Extension>()) {
Expand Down Expand Up @@ -478,23 +499,7 @@ Client_Hello_12::Client_Hello_12(Handshake_IO& io,
m_data->extensions().add(new Certificate_Status_Request({}, {}));
}

// RFC 7919 3.
// A client that offers a group MUST be able and willing to perform a DH
// key exchange using that group.
//
// We don't support hybrid key exchange in TLS 1.2
const std::vector<Group_Params> kex_groups = policy.key_exchange_groups();
std::vector<Group_Params> compatible_kex_groups;
std::copy_if(kex_groups.begin(), kex_groups.end(), std::back_inserter(compatible_kex_groups), [](const auto group) {
return is_ecdh(group) || is_dh(group) || is_x25519(group);
});

auto supported_groups = std::make_unique<Supported_Groups>(std::move(compatible_kex_groups));

if(!supported_groups->ec_groups().empty()) {
m_data->extensions().add(new Supported_Point_Formats(policy.use_ecc_point_compression()));
}
m_data->extensions().add(supported_groups.release());
m_data->extensions().add(make_tls12_supported_groups(policy));

m_data->extensions().add(new Signature_Algorithms(policy.acceptable_signature_schemes()));
if(auto cert_signing_prefs = policy.acceptable_certificate_signature_schemes()) {
Expand Down Expand Up @@ -573,24 +578,7 @@ Client_Hello_12::Client_Hello_12(Handshake_IO& io,
m_data->extensions().add(new Certificate_Status_Request({}, {}));
}

// RFC 7919 3.
// A client that offers a group MUST be able and willing to perform a DH
// key exchange using that group.
//
// We don't support hybrid key exchange in TLS 1.2
const std::vector<Group_Params> kex_groups = policy.key_exchange_groups();
std::vector<Group_Params> compatible_kex_groups;
std::copy_if(kex_groups.begin(), kex_groups.end(), std::back_inserter(compatible_kex_groups), [](const auto group) {
return is_ecdh(group) || is_dh(group) || is_x25519(group);
});

auto supported_groups = std::make_unique<Supported_Groups>(std::move(compatible_kex_groups));

if(!supported_groups->ec_groups().empty()) {
m_data->extensions().add(new Supported_Point_Formats(policy.use_ecc_point_compression()));
}

m_data->extensions().add(supported_groups.release());
m_data->extensions().add(make_tls12_supported_groups(policy));

m_data->extensions().add(new Signature_Algorithms(policy.acceptable_signature_schemes()));
if(auto cert_signing_prefs = policy.acceptable_certificate_signature_schemes()) {
Expand Down
7 changes: 4 additions & 3 deletions src/lib/tls/tls13/tls_extensions_key_share.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ class Key_Share_Entry {
const Policy& policy,
Callbacks& cb,
RandomNumberGenerator& rng) {
auto kem_result = cb.tls_kem_encapsulate(m_group, client_share.m_key_exchange, rng, policy);
m_key_exchange = std::move(kem_result.encapsulated_shared_key());
return std::move(kem_result.shared_key());
auto [encapsulated_shared_key, shared_key] =
KEM_Encapsulation::destructure(cb.tls_kem_encapsulate(m_group, client_share.m_key_exchange, rng, policy));
m_key_exchange = std::move(encapsulated_shared_key);
return std::move(shared_key);
}

/**
Expand Down
108 changes: 48 additions & 60 deletions src/lib/tls/tls13_pqc/hybrid_public_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,8 @@ namespace Botan::TLS {

namespace {

template <typename RetT, typename KeyT, typename ReducerT>
RetT reduce(const std::vector<KeyT>& keys, RetT acc, ReducerT reducer) {
for(const KeyT& key : keys) {
acc = reducer(std::move(acc), key);
}
return acc;
}

std::vector<std::pair<std::string, std::string>> algorithm_specs_for_group(Group_Params group) {
BOTAN_ASSERT_NOMSG(is_hybrid(group));
BOTAN_ARG_CHECK(is_hybrid(group), "Group is not hybrid");

switch(group) {
case Group_Params::HYBRID_X25519_KYBER_512_R3_OQS:
Expand Down Expand Up @@ -65,9 +57,9 @@ std::vector<AlgorithmIdentifier> algorithm_identifiers_for_group(Group_Params gr
//
// TODO: This is inconvenient, confusing and error-prone. Find a better way
// to load arbitrary public keys.
std::transform(specs.begin(), specs.end(), std::back_inserter(result), [](const auto& spec) {
return AlgorithmIdentifier(spec.second, AlgorithmIdentifier::USE_EMPTY_PARAM);
});
for(const auto& spec : specs) {
result.push_back(AlgorithmIdentifier(spec.second, AlgorithmIdentifier::USE_EMPTY_PARAM));
}

return result;
}
Expand Down Expand Up @@ -130,33 +122,37 @@ Hybrid_KEM_PublicKey::Hybrid_KEM_PublicKey(std::vector<std::unique_ptr<Public_Ke

std::transform(
pks.begin(), pks.end(), std::back_inserter(m_public_keys), [](auto& key) -> std::unique_ptr<Public_Key> {
if(key->supports_operation(PublicKeyOperation::KeyAgreement)) {
if(key->supports_operation(PublicKeyOperation::KeyAgreement) &&
!key->supports_operation(PublicKeyOperation::KeyEncapsulation)) {
return std::make_unique<KEX_to_KEM_Adapter_PublicKey>(std::move(key));
} else {
return std::move(key);
}
});

m_key_length = reduce(m_public_keys, size_t(0), [](size_t kl, const auto& key) { return kl + key->key_length(); });
m_estimated_strength = reduce(
m_public_keys, size_t(0), [](size_t es, const auto& key) { return std::max(es, key->estimated_strength()); });
}

std::string Hybrid_KEM_PublicKey::algo_name() const {
std::string algo_name = "Hybrid(";
std::ostringstream algo_name("Hybrid(");
for(size_t i = 0; i < m_public_keys.size(); ++i) {
algo_name += m_public_keys[i]->algo_name();
if(i < m_public_keys.size() - 1) {
algo_name += ",";
if(i > 0) {
algo_name << ",";
}
algo_name << m_public_keys[i]->algo_name();
}
algo_name += ")";
return algo_name;
algo_name << ")";
return algo_name.str();
}

size_t Hybrid_KEM_PublicKey::estimated_strength() const {
return reduce(
m_public_keys, size_t(0), [](size_t es, const auto& key) { return std::max(es, key->estimated_strength()); });
return m_estimated_strength;
}

size_t Hybrid_KEM_PublicKey::key_length() const {
return reduce(m_public_keys, size_t(0), [](size_t kl, const auto& key) { return kl + key->key_length(); });
return m_key_length;
}

bool Hybrid_KEM_PublicKey::check_key(RandomNumberGenerator& rng, bool strong) const {
Expand Down Expand Up @@ -208,25 +204,18 @@ class Hybrid_KEM_Encryption_Operation final : public PK_Ops::KEM_Encryption_with
Hybrid_KEM_Encryption_Operation(const Hybrid_KEM_PublicKey& key,
std::string_view kdf,
std::string_view provider) :
PK_Ops::KEM_Encryption_with_KDF(kdf) {
std::transform(
key.public_keys().begin(),
key.public_keys().end(),
std::back_inserter(m_kem_encryptors),
[&](const auto& pubkey) { return std::make_unique<PK_KEM_Encryptor>(*pubkey, "Raw", provider); });
PK_Ops::KEM_Encryption_with_KDF(kdf), m_raw_kem_shared_key_length(0), m_encapsulated_key_length(0) {
for(const auto& k : key.public_keys()) {
auto kem = std::make_unique<PK_KEM_Encryptor>(*k, "Raw", provider);
m_raw_kem_shared_key_length += kem->shared_key_length(0 /* no KDF */);
m_encapsulated_key_length += kem->encapsulated_key_length();
m_kem_encryptors.push_back(std::move(kem));
}
}

size_t raw_kem_shared_key_length() const override {
return reduce(m_kem_encryptors, size_t(0), [](auto acc, const auto& kem_enc) {
return acc + kem_enc->shared_key_length(0 /* no KDF */);
});
}
size_t raw_kem_shared_key_length() const override { return m_raw_kem_shared_key_length; }

size_t encapsulated_key_length() const override {
return reduce(m_kem_encryptors, size_t(0), [](auto acc, const auto& kem_enc) {
return acc + kem_enc->encapsulated_key_length();
});
}
size_t encapsulated_key_length() const override { return m_encapsulated_key_length; }

void raw_kem_encrypt(std::span<uint8_t> out_encapsulated_key,
std::span<uint8_t> raw_shared_key,
Expand All @@ -248,6 +237,8 @@ class Hybrid_KEM_Encryption_Operation final : public PK_Ops::KEM_Encryption_with
// Note: PK_KEM_Encryptor can neither be moved nor copied. Hence, we wrap
// it into a std::unique_ptr<> before storing it in a variant/vector.
std::vector<std::unique_ptr<PK_KEM_Encryptor>> m_kem_encryptors;
size_t m_raw_kem_shared_key_length;
size_t m_encapsulated_key_length;
};

} // namespace
Expand All @@ -261,11 +252,10 @@ namespace {

auto extract_public_keys(const std::vector<std::unique_ptr<Private_Key>>& private_keys) {
std::vector<std::unique_ptr<Public_Key>> public_keys;
std::transform(
private_keys.begin(), private_keys.end(), std::back_inserter(public_keys), [](const auto& private_key) {
BOTAN_ARG_CHECK(private_key != nullptr, "List of private keys contains a nullptr");
return private_key->public_key();
});
for(const auto& private_key : private_keys) {
BOTAN_ARG_CHECK(private_key != nullptr, "List of private keys contains a nullptr");
public_keys.push_back(private_key->public_key());
}
return public_keys;
}

Expand All @@ -274,10 +264,10 @@ auto extract_public_keys(const std::vector<std::unique_ptr<Private_Key>>& privat
std::unique_ptr<Hybrid_KEM_PrivateKey> Hybrid_KEM_PrivateKey::generate_from_group(Group_Params group,
RandomNumberGenerator& rng) {
const auto algo_spec = algorithm_specs_for_group(group);
std::vector<std::unique_ptr<Private_Key>> private_keys;
std::transform(algo_spec.begin(), algo_spec.end(), std::back_inserter(private_keys), [&](const auto& spec) {
return create_private_key(spec.first, rng, spec.second);
});
std::vector<std::unique_ptr<Private_Key>> private_keys(algo_spec.size());
for(const auto& spec : algo_spec) {
private_keys.push_back(create_private_key(spec.first, rng, spec.second));
}
return std::make_unique<Hybrid_KEM_PrivateKey>(std::move(private_keys));
}

Expand All @@ -294,7 +284,8 @@ Hybrid_KEM_PrivateKey::Hybrid_KEM_PrivateKey(std::vector<std::unique_ptr<Private

std::transform(
sks.begin(), sks.end(), std::back_inserter(m_private_keys), [](auto& key) -> std::unique_ptr<Private_Key> {
if(key->supports_operation(PublicKeyOperation::KeyAgreement)) {
if(key->supports_operation(PublicKeyOperation::KeyAgreement) &&
!key->supports_operation(PublicKeyOperation::KeyEncapsulation)) {
auto ka_key = dynamic_cast<PK_Key_Agreement_Key*>(key.get());
BOTAN_ASSERT_NONNULL(ka_key);
(void)key.release();
Expand All @@ -310,10 +301,10 @@ secure_vector<uint8_t> Hybrid_KEM_PrivateKey::private_key_bits() const {
}

std::unique_ptr<Public_Key> Hybrid_KEM_PrivateKey::public_key() const {
std::vector<std::unique_ptr<Public_Key>> pks;
std::transform(m_private_keys.cbegin(), m_private_keys.cend(), std::back_inserter(pks), [](const auto& sk) {
return sk->public_key();
});
std::vector<std::unique_ptr<Public_Key>> pks(m_private_keys.size());
for(const auto& sk : m_private_keys) {
pks.push_back(sk->public_key());
}
return std::make_unique<Hybrid_KEM_PublicKey>(std::move(pks));
}

Expand All @@ -330,14 +321,11 @@ class Hybrid_KEM_Decryption final : public PK_Ops::KEM_Decryption_with_KDF {
const std::string_view kdf,
const std::string_view provider) :
PK_Ops::KEM_Decryption_with_KDF(kdf) {
std::transform(key.private_keys().begin(),
key.private_keys().end(),
std::back_inserter(m_decryptors_with_encapsulation_lengths),
[&](const auto& private_key) {
PK_KEM_Encryptor enc(*private_key, "Raw");
return std::pair{std::make_unique<PK_KEM_Decryptor>(*private_key, rng, "Raw", provider),
enc.encapsulated_key_length()};
});
for(const auto& private_key : key.private_keys()) {
PK_KEM_Encryptor enc(*private_key, "Raw");
m_decryptors_with_encapsulation_lengths.push_back(std::pair{
std::make_unique<PK_KEM_Decryptor>(*private_key, rng, "Raw", provider), enc.encapsulated_key_length()});
}
}

void raw_kem_decrypt(std::span<uint8_t> out_shared_key, std::span<const uint8_t> encap_key) override {
Expand Down
10 changes: 7 additions & 3 deletions src/lib/tls/tls13_pqc/hybrid_public_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ namespace Botan::TLS {
* Composes a number of public keys as defined in this IETF draft:
* https://datatracker.ietf.org/doc/html/draft-ietf-tls-hybrid-design-04
*
* To an upstream user this composite key pair is presented as a KEM.
* To an upstream user, this composite key pair is presented as a KEM.
* Compositions of at least two (and potentially more) public keys are legal.
* Each individual key pair must either work as a KEX or as a KEM. Currently,
* the class can deal with ECC keys anc Kyber.
* the class can deal with ECC keys and Kyber.
*
* Note that this class is not generic enough for arbitrary use cases but
* serializes and parses keys and ciphertexts as described in above-mentioned
* serializes and parses keys and ciphertexts as described in the above-mentioned
* IETF draft for a post-quantum TLS 1.3.
*/
class BOTAN_TEST_API Hybrid_KEM_PublicKey : public virtual Public_Key {
Expand Down Expand Up @@ -65,6 +65,10 @@ class BOTAN_TEST_API Hybrid_KEM_PublicKey : public virtual Public_Key {

protected:
std::vector<std::unique_ptr<Public_Key>> m_public_keys;

private:
size_t m_key_length;
size_t m_estimated_strength;
};

BOTAN_DIAGNOSTIC_PUSH
Expand Down
4 changes: 2 additions & 2 deletions src/lib/tls/tls_algos.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ constexpr bool is_dh(const Group_Params group) {
group == Group_Params::FFDHE_6144 || group == Group_Params::FFDHE_8192;
}

constexpr bool is_kyber(const Group_Params group) {
constexpr bool is_pure_kyber(const Group_Params group) {
return group == Group_Params::KYBER_512_R3 || group == Group_Params::KYBER_768_R3 ||
group == Group_Params::KYBER_1024_R3;
}
Expand All @@ -150,7 +150,7 @@ constexpr bool is_hybrid(const Group_Params group) {
}

constexpr bool is_kem(const Group_Params group) {
return is_kyber(group) || is_hybrid(group);
return is_pure_kyber(group) || is_hybrid(group);
}

std::string group_param_to_string(Group_Params group);
Expand Down
4 changes: 2 additions & 2 deletions src/lib/tls/tls_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ bool TLS::Callbacks::tls_verify_message(const Public_Key& key,

std::unique_ptr<Private_Key> TLS::Callbacks::tls_kem_generate_key(TLS::Group_Params group, RandomNumberGenerator& rng) {
#if defined(BOTAN_HAS_KYBER)
if(is_kyber(group)) {
if(is_pure_kyber(group)) {
return std::make_unique<Kyber_PrivateKey>(rng, KyberMode(group_param_to_string(group)));
}
#endif
Expand Down Expand Up @@ -170,7 +170,7 @@ KEM_Encapsulation TLS::Callbacks::tls_kem_encapsulate(TLS::Group_Params group,
#endif

#if defined(BOTAN_HAS_KYBER)
if(is_kyber(group)) {
if(is_pure_kyber(group)) {
return std::make_unique<Kyber_PublicKey>(encoded_public_key, KyberMode(group_param_to_string(group)));
}
#endif
Expand Down
Loading

0 comments on commit 908edb1

Please sign in to comment.