Skip to content

Commit

Permalink
templated handshake states
Browse files Browse the repository at this point in the history
Co-authored-by: René Meusel <rene.meusel@nexenio.com>
  • Loading branch information
Hannes Rantzsch and reneme committed Feb 23, 2022
1 parent 28d940a commit 91d05f0
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 36 deletions.
1 change: 0 additions & 1 deletion src/lib/tls/tls13/tls_channel_impl_13.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace TLS {

class Connection_Sequence_Numbers;
class Connection_Cipher_State;
class Handshake_State_13;

/**
* Generic interface for TLSv.12 endpoint
Expand Down
14 changes: 7 additions & 7 deletions src/lib/tls/tls13/tls_handshake_state_13.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,39 @@

#include "botan/internal/tls_handshake_state_13.h"

namespace Botan::TLS {
namespace Botan::TLS::Internal {

Client_Hello_13& Handshake_State_13::store(Client_Hello_13 client_hello, const bool)
Client_Hello_13& Handshake_State_13_Base::store(Client_Hello_13 client_hello, const bool)
{
m_client_hello = std::move(client_hello);
return m_client_hello.value();
}

Server_Hello_13& Handshake_State_13::store(Server_Hello_13 server_hello, const bool)
Server_Hello_13& Handshake_State_13_Base::store(Server_Hello_13 server_hello, const bool)
{
m_server_hello = std::move(server_hello);
return m_server_hello.value();
}

Encrypted_Extensions& Handshake_State_13::store(Encrypted_Extensions encrypted_extensions, const bool)
Encrypted_Extensions& Handshake_State_13_Base::store(Encrypted_Extensions encrypted_extensions, const bool)
{
m_encrypted_extensions = std::move(encrypted_extensions);
return m_encrypted_extensions.value();
}

Certificate_13& Handshake_State_13::store(Certificate_13 certificate, const bool)
Certificate_13& Handshake_State_13_Base::store(Certificate_13 certificate, const bool)
{
m_server_certs = std::move(certificate);
return m_server_certs.value();
}

Certificate_Verify_13& Handshake_State_13::store(Certificate_Verify_13 certificate_verify, const bool)
Certificate_Verify_13& Handshake_State_13_Base::store(Certificate_Verify_13 certificate_verify, const bool)
{
m_server_verify = std::move(certificate_verify);
return m_server_verify.value();
}

Finished_13& Handshake_State_13::store(Finished_13 finished, const bool from_peer)
Finished_13& Handshake_State_13_Base::store(Finished_13 finished, const bool from_peer)
{
auto& target = ((m_side == CLIENT) == from_peer)
? m_server_finished
Expand Down
66 changes: 43 additions & 23 deletions src/lib/tls/tls13/tls_handshake_state_13.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,21 @@

namespace Botan::TLS {

class BOTAN_TEST_API Handshake_State_13
namespace Internal {
class Handshake_State_13_Base
{
public:
Handshake_State_13(Connection_Side side) : m_side(side) {}

const Client_Hello_13& client_hello() const { return get(m_client_hello); }
const Server_Hello_13& server_hello() const { return get(m_server_hello); }
const Encrypted_Extensions& encrypted_extensions() const { return get(m_encrypted_extensions); }
const Certificate_13& certificate() const { return get(m_server_certs); }
const Certificate_Verify_13& certificate_verify() const { return get(m_server_verify); }
const Finished_13& client_finished() const { return get(m_server_finished); }
const Finished_13& server_finished() const { return get(m_client_finished); }
const Finished_13& client_finished() const { return get(m_client_finished); }
const Finished_13& server_finished() const { return get(m_server_finished); }

protected:
Handshake_State_13_Base(Connection_Side whoami) : m_side(whoami) {}

Client_Hello_13& store(Client_Hello_13 client_hello, const bool from_peer);
Server_Hello_13& store(Server_Hello_13 server_hello, const bool from_peer);
Encrypted_Extensions& store(Encrypted_Extensions encrypted_extensions, const bool from_peer);
Expand All @@ -50,7 +51,6 @@ class BOTAN_TEST_API Handshake_State_13
return opt.value();
}

private:
Connection_Side m_side;

std::optional<Client_Hello_13> m_client_hello;
Expand All @@ -61,34 +61,54 @@ class BOTAN_TEST_API Handshake_State_13
std::optional<Finished_13> m_server_finished;
std::optional<Finished_13> m_client_finished;
};
}

class BOTAN_TEST_API Client_Handshake_State_13 : public Handshake_State_13
/**
* Place to store TLS handshake messages
*
* This class is used to keep all handshake messages that have been received from and sent to
* the peer as part of the TLS 1.3 handshake. Getters are provided for all message types.
* Specializations for the client and server side provide specific setters in the form of
* `sent` and `received` that only allow those types of handshake messages that are sensible
* for the respective connection side.
*
* The handshake state machine as described in RFC 8446 Appendix A is NOT validated here.
*/
template <Connection_Side whoami, typename Outbound_Message_T, typename Inbound_Message_T>
class BOTAN_TEST_API Handshake_State_13 : public Internal::Handshake_State_13_Base
{
public:
Client_Handshake_State_13() : Handshake_State_13(Connection_Side::CLIENT) {}
Handshake_State_13() : Handshake_State_13_Base(whoami) {}

decltype(auto) sent(Client_Handshake_13_Message message)
{
return std::visit([&](auto msg) -> Handshake_Message_13_Ref
decltype(auto) sent(Outbound_Message_T message)
{
return store(std::move(msg), false);
}, std::move(message));
}
return std::visit([&](auto msg) -> Handshake_Message_13_Ref
{
return store(std::move(msg), false);
}, std::move(message));
}

decltype(auto) received(Handshake_Message_13 message)
{
return std::visit([&](auto msg) -> Server_Handshake_13_Message_Ref
decltype(auto) received(Handshake_Message_13 message)
{
if constexpr (std::is_constructible_v<Server_Handshake_13_Message, decltype(msg)>)
return std::visit([&](auto msg) -> as_wrapped_references_t<Inbound_Message_T>
{
return store(std::move(msg), true);
}
if constexpr(std::is_constructible_v<Inbound_Message_T, decltype(msg)>)
{
return store(std::move(msg), true);
}

throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "client received an illegal handshake message");
}, std::move(message));
}
throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "received an illegal handshake message");
}, std::move(message));
}
};

using Client_Handshake_State_13 = Handshake_State_13<Connection_Side::CLIENT,
Client_Handshake_13_Message,
Server_Handshake_13_Message>;

using Server_Handshake_State_13 = Handshake_State_13<Connection_Side::SERVER,
Server_Handshake_13_Message,
Client_Handshake_13_Message>;
}

#endif
60 changes: 55 additions & 5 deletions src/tests/test_tls_handshake_state_13.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Test::Result CHECK(const char* name, FunT check_fun)
}

const auto client_hello_message = Botan::hex_decode( // from RFC 8448
"01 00 00 c0 03 03 cb"
"03 03 cb"
"34 ec b1 e7 81 63 ba 1c 38 c6 da cb 19 6a 6d ff a2 1a 8d 99 12"
"ec 18 a2 ef 62 83 02 4d ec e7 00 00 06 13 01 13 03 13 02 01 00"
"00 91 00 00 00 0b 00 09 00 00 06 73 65 72 76 65 72 ff 01 00 01"
Expand All @@ -50,16 +50,51 @@ const auto client_hello_message = Botan::hex_decode( // from RFC 8448
"01 04 02 05 02 06 02 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01");

const auto server_hello_message = Botan::hex_decode(
"02 00 00 56 03 03 a6"
"03 03 a6"
"af 06 a4 12 18 60 dc 5e 6e 60 24 9c d3 4c 95 93 0c 8a c5 cb 14"
"34 da c1 55 77 2e d3 e2 69 28 00 13 01 00 00 2e 00 33 00 24 00"
"1d 00 20 c9 82 88 76 11 20 95 fe 66 76 2b db f7 c6 72 e1 56 d6"
"cc 25 3b 83 3d f1 dd 69 b1 b0 4e 75 1f 0f 00 2b 00 02 03 04");

const auto server_finished_message = Botan::hex_decode(
"9b 9b 14 1d 90 63 37 fb"
"d2 cb dc e7 1d f4 de da 4a b4 2c 30"
"95 72 cb 7f ff ee 54 54 b7 8f 07 18");

const auto client_finished_message = Botan::hex_decode(
"14 00 00 20 a8 ec 43 6d 67 76 34 ae 52 5a c1"
"a8 ec 43 6d 67 76 34 ae 52 5a c1"
"fc eb e1 1a 03 9e c1 76 94 fa c6 e9 85 27 b6 42 f2 ed d5 ce 61");

std::vector<Test::Result> finished_message_handling()
{
return
{
CHECK("Client sends and receives Finished messages", [&](auto& result)
{
Client_Handshake_State_13 state;

Finished_13 client_finished(client_finished_message);

auto client_fin = state.sent(std::move(client_finished));
result.require("client can send client finished",
std::holds_alternative<std::reference_wrapper<Finished_13>>(client_fin));
result.test_throws("not stored as server Finished", [&]
{
state.server_finished();
});
result.test_eq("correct client Finished stored", state.client_finished().serialize(), client_finished_message);

Finished_13 server_finished(server_finished_message);

auto server_fin = state.received(std::move(server_finished));
result.require("client can receive server finished",
std::holds_alternative<std::reference_wrapper<Finished_13>>(server_fin));
result.test_eq("correct client Finished stored", state.client_finished().serialize(), client_finished_message);
result.test_eq("correct server Finished stored", state.server_finished().serialize(), server_finished_message);
}),
};
}

std::vector<Test::Result> handshake_message_filtering()
{
return
Expand All @@ -68,25 +103,40 @@ std::vector<Test::Result> handshake_message_filtering()
{
Client_Handshake_State_13 state;

Client_Hello_13 client_hello({client_hello_message.cbegin()+4, client_hello_message.cend()});
Client_Hello_13 client_hello(client_hello_message);

auto filtered = state.sent(std::move(client_hello));
result.confirm("client can send client hello",
std::holds_alternative<std::reference_wrapper<Client_Hello_13>>(filtered));

result.test_eq("correct client hello stored", state.client_hello().serialize(), client_hello_message);

result.template test_throws<TLS_Exception>("client cannot receive client hello",
"client received an illegal handshake message", [&]
"received an illegal handshake message", [&]
{
state.received(std::move(client_hello));
});
}),
CHECK("Client with server hello", [&](auto& result)
{
Client_Handshake_State_13 state;

Server_Hello_13 server_hello(server_hello_message);

auto filtered = state.received(std::move(server_hello));
result.confirm("client can receive server hello",
std::holds_alternative<std::reference_wrapper<Server_Hello_13>>(filtered));

result.test_eq("correct server hello stored", state.server_hello().serialize(), server_hello_message);
}),
};
}

} // namespace

namespace Botan_Tests {
BOTAN_REGISTER_TEST_FN("tls", "tls_handshake_state_13",
finished_message_handling,
handshake_message_filtering);
}

Expand Down

0 comments on commit 91d05f0

Please sign in to comment.