Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some clang-tidy warnings in the TLS 1.3 PSK extension code #3619

Merged
merged 1 commit into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib/tls/tls13/tls_cipher_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ bool Cipher_State::is_compatible_with(const Ciphersuite& cipher) const {
return true;
}

std::vector<uint8_t> Cipher_State::psk_binder_mac(const Transcript_Hash& transcript_hash_with_truncated_client_hello) {
std::vector<uint8_t> Cipher_State::psk_binder_mac(
const Transcript_Hash& transcript_hash_with_truncated_client_hello) const {
BOTAN_ASSERT_NOMSG(m_state == State::PskBinder);

auto hmac = HMAC(m_hash->new_object());
Expand Down
2 changes: 1 addition & 1 deletion src/lib/tls/tls13/tls_cipher_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class BOTAN_TEST_API Cipher_State {
* the transcript hash passed into this method is computed from a partial
* Client Hello (RFC 8446 4.2.11.2)
*/
std::vector<uint8_t> psk_binder_mac(const Transcript_Hash& transcript_hash_with_truncated_client_hello);
std::vector<uint8_t> psk_binder_mac(const Transcript_Hash& transcript_hash_with_truncated_client_hello) const;

/**
* Calculate the MAC for a TLS "Finished" handshake message (RFC 8446 4.4.4)
Expand Down
97 changes: 62 additions & 35 deletions src/lib/tls/tls13/tls_extensions_psk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ decltype(auto) calculate_age(std::chrono::system_clock::time_point then, std::ch
return std::chrono::duration_cast<std::chrono::milliseconds>(now - then);
}

struct Client_PSK {
class Client_PSK {
public:
Client_PSK(Session_with_Handle& session_to_resume, std::chrono::system_clock::time_point timestamp) :
Client_PSK(PskIdentity(session_to_resume.handle.opaque_handle(),
calculate_age(session_to_resume.session.start_time(), timestamp),
Expand All @@ -42,13 +43,13 @@ struct Client_PSK {
session_to_resume.session.extract_master_secret(),
Cipher_State::PSK_Type::Resumption) {}

Client_PSK(PskIdentity id, std::vector<uint8_t> bndr) : identity(std::move(id)), binder(std::move(bndr)) {}
Client_PSK(PskIdentity id, std::vector<uint8_t> bndr) : m_identity(std::move(id)), m_binder(std::move(bndr)) {}

Client_PSK(PskIdentity id,
std::string_view prf_algo,
secure_vector<uint8_t>&& master_secret,
Cipher_State::PSK_Type psk_type) :
identity(std::move(id)),
m_identity(std::move(id)),

// RFC 8446 4.2.11.2
// Each entry in the binders list is computed as an HMAC over a transcript
Expand All @@ -63,24 +64,53 @@ struct Client_PSK {
// Hence, we fill the binders with dummy values of the correct length and use
// `Client_Hello_13::truncate()` to split them off before calculating the
// transcript hash that underpins the PSK binders. S.a. `calculate_binders()`
binder(HashFunction::create_or_throw(prf_algo)->output_length()),
cipher_state(
m_binder(HashFunction::create_or_throw(prf_algo)->output_length()),
m_cipher_state(
Cipher_State::init_with_psk(Connection_Side::Client, psk_type, std::move(master_secret), prf_algo)) {}

PskIdentity identity;
std::vector<uint8_t> binder;
const PskIdentity& identity() const { return m_identity; }

const std::vector<uint8_t>& binder() const { return m_binder; }

void set_binder(std::vector<uint8_t> binder) { m_binder = std::move(binder); }

const Cipher_State& cipher_state() const {
BOTAN_ASSERT_NONNULL(m_cipher_state);
return *m_cipher_state;
}

std::unique_ptr<Cipher_State> take_cipher_state() { return std::move(m_cipher_state); }

private:
PskIdentity m_identity;
std::vector<uint8_t> m_binder;

// Clients set up associated cipher states for PSKs
// Servers leave this as nullptr
std::unique_ptr<Cipher_State> cipher_state;
std::unique_ptr<Cipher_State> m_cipher_state;
};

struct Server_PSK {
uint16_t selected_identity;
class Server_PSK {
public:
Server_PSK(uint16_t id) : m_selected_identity(id), m_session_to_resume(std::nullopt) {}

Server_PSK(uint16_t id, Session session) : m_selected_identity(id), m_session_to_resume(std::move(session)) {}

uint16_t selected_identity() const { return m_selected_identity; }

Session take_session_to_resume() {
BOTAN_STATE_CHECK(m_session_to_resume.has_value());
Session s = std::move(m_session_to_resume.value());
m_session_to_resume = std::nullopt;
return s;
}

private:
uint16_t m_selected_identity;

// Servers store the Session to resume from the selected PSK
// Clients leave this as std::nullopt
std::optional<Session> session_to_resume;
std::optional<Session> m_session_to_resume;
};

} // namespace
Expand All @@ -101,8 +131,8 @@ PSK::PSK(TLS_Data_Reader& reader, uint16_t extension_size, Handshake_Type messag
throw TLS_Exception(Alert::DecodeError, "Server provided a malformed PSK extension");
}

m_impl = std::make_unique<PSK_Internal>(
Server_PSK{.selected_identity = reader.get_uint16_t(), .session_to_resume = std::nullopt});
const uint16_t selected_id = reader.get_uint16_t();
m_impl = std::make_unique<PSK_Internal>(Server_PSK(selected_id));
} else if(message_type == Handshake_Type::ClientHello) {
const auto identities_length = reader.get_uint16_t();
const auto identities_offset = reader.read_so_far();
Expand Down Expand Up @@ -155,8 +185,7 @@ PSK::PSK(Session_with_Handle& session_to_resume, Callbacks& callbacks) {
}

PSK::PSK(Session session_to_resume, const uint16_t psk_index) :
m_impl(std::make_unique<PSK_Internal>(
Server_PSK{.selected_identity = psk_index, .session_to_resume = std::move(session_to_resume)})) {}
m_impl(std::make_unique<PSK_Internal>(Server_PSK(psk_index, std::move(session_to_resume)))) {}

PSK::~PSK() = default;

Expand All @@ -173,7 +202,7 @@ std::unique_ptr<Cipher_State> PSK::select_cipher_state(const PSK& server_psk, co
BOTAN_STATE_CHECK(std::holds_alternative<std::vector<Client_PSK>>(m_impl->psk));
BOTAN_STATE_CHECK(std::holds_alternative<Server_PSK>(server_psk.m_impl->psk));

const auto id = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity;
const auto id = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
auto& ids = std::get<std::vector<Client_PSK>>(m_impl->psk);

// RFC 8446 4.2.11
Expand All @@ -185,7 +214,7 @@ std::unique_ptr<Cipher_State> PSK::select_cipher_state(const PSK& server_psk, co
throw TLS_Exception(Alert::IllegalParameter, "PSK identity selected by server is out of bounds");
}

auto cipher_state = std::exchange(ids[id].cipher_state, nullptr);
auto cipher_state = ids[id].take_cipher_state();
BOTAN_ASSERT_NONNULL(cipher_state);

// destroy cipher states and PSKs that were not selected by the server
Expand All @@ -212,7 +241,7 @@ std::unique_ptr<PSK> PSK::select_offered_psk(const Ciphersuite& cipher,
auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);
std::vector<PskIdentity> psk_identities;
std::transform(
psks.begin(), psks.end(), std::back_inserter(psk_identities), [&](const auto& psk) { return psk.identity; });
psks.begin(), psks.end(), std::back_inserter(psk_identities), [&](const auto& psk) { return psk.identity(); });

if(auto selection = session_mgr.choose_from_offered_tickets(psk_identities, cipher.prf_algo(), callbacks, policy)) {
auto& [session, psk_index] = selection.value();
Expand All @@ -237,19 +266,15 @@ void PSK::filter(const Ciphersuite& cipher) {
auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);

const auto r = std::remove_if(psks.begin(), psks.end(), [&](const auto& psk) {
BOTAN_ASSERT_NONNULL(psk.cipher_state);
return !psk.cipher_state->is_compatible_with(cipher);
const auto& cipher_state = psk.cipher_state();
return !cipher_state.is_compatible_with(cipher);
});
psks.erase(r, psks.end());
}

Session PSK::take_session_to_resume() {
BOTAN_STATE_CHECK(std::holds_alternative<Server_PSK>(m_impl->psk));
auto& session_to_resume = std::get<Server_PSK>(m_impl->psk).session_to_resume;
BOTAN_STATE_CHECK(session_to_resume.has_value());
Session s = std::move(session_to_resume.value());
session_to_resume = std::nullopt;
return s;
return std::get<Server_PSK>(m_impl->psk).take_session_to_resume();
}

std::vector<uint8_t> PSK::serialize(Connection_Side side) const {
Expand All @@ -259,24 +284,26 @@ std::vector<uint8_t> PSK::serialize(Connection_Side side) const {
[&](const Server_PSK& psk) {
BOTAN_STATE_CHECK(side == Connection_Side::Server);
result.reserve(2);
result.push_back(get_byte<0>(psk.selected_identity));
result.push_back(get_byte<1>(psk.selected_identity));
const uint16_t id = psk.selected_identity();
result.push_back(get_byte<0>(id));
result.push_back(get_byte<1>(id));
},
[&](const std::vector<Client_PSK>& psks) {
BOTAN_STATE_CHECK(side == Connection_Side::Client);

std::vector<uint8_t> identities;
std::vector<uint8_t> binders;
for(const auto& psk : psks) {
append_tls_length_value(identities, psk.identity.identity(), 2);
const auto& psk_identity = psk.identity();
append_tls_length_value(identities, psk_identity.identity(), 2);

const auto obfuscated_ticket_age = psk.identity.obfuscated_age();
const uint32_t obfuscated_ticket_age = psk_identity.obfuscated_age();
identities.push_back(get_byte<0>(obfuscated_ticket_age));
identities.push_back(get_byte<1>(obfuscated_ticket_age));
identities.push_back(get_byte<2>(obfuscated_ticket_age));
identities.push_back(get_byte<3>(obfuscated_ticket_age));

append_tls_length_value(binders, psk.binder, 1);
append_tls_length_value(binders, psk.binder(), 1);
}

append_tls_length_value(result, identities, 2);
Expand All @@ -293,21 +320,21 @@ void PSK::calculate_binders(const Transcript_Hash_State& truncated_transcript_ha
BOTAN_ASSERT_NOMSG(std::holds_alternative<std::vector<Client_PSK>>(m_impl->psk));
for(auto& psk : std::get<std::vector<Client_PSK>>(m_impl->psk)) {
auto tth = truncated_transcript_hash.clone();
BOTAN_ASSERT_NONNULL(psk.cipher_state);
tth.set_algorithm(psk.cipher_state->hash_algorithm());
psk.binder = psk.cipher_state->psk_binder_mac(tth.truncated());
const auto& cipher_state = psk.cipher_state();
tth.set_algorithm(cipher_state.hash_algorithm());
psk.set_binder(cipher_state.psk_binder_mac(tth.truncated()));
}
}

bool PSK::validate_binder(const PSK& server_psk, const std::vector<uint8_t>& binder) const {
BOTAN_STATE_CHECK(std::holds_alternative<std::vector<Client_PSK>>(m_impl->psk));
BOTAN_STATE_CHECK(std::holds_alternative<Server_PSK>(server_psk.m_impl->psk));

const auto index = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity;
const uint16_t index = std::get<Server_PSK>(server_psk.m_impl->psk).selected_identity();
const auto& psks = std::get<std::vector<Client_PSK>>(m_impl->psk);

BOTAN_STATE_CHECK(index < psks.size());
return psks[index].binder == binder;
return psks[index].binder() == binder;
}

} // namespace Botan::TLS
Expand Down