diff --git a/src/lib/tls/tls13/tls_channel_impl_13.h b/src/lib/tls/tls13/tls_channel_impl_13.h index 5e217a4ffd8..89013ebb739 100644 --- a/src/lib/tls/tls13/tls_channel_impl_13.h +++ b/src/lib/tls/tls13/tls_channel_impl_13.h @@ -21,7 +21,6 @@ namespace TLS { class Connection_Sequence_Numbers; class Connection_Cipher_State; -class Handshake_State_13; /** * Generic interface for TLSv.12 endpoint diff --git a/src/lib/tls/tls13/tls_handshake_state_13.cpp b/src/lib/tls/tls13/tls_handshake_state_13.cpp index 37c04382713..315f3f7403c 100644 --- a/src/lib/tls/tls13/tls_handshake_state_13.cpp +++ b/src/lib/tls/tls13/tls_handshake_state_13.cpp @@ -8,39 +8,39 @@ #include "botan/internal/tls_handshake_state_13.h" -namespace Botan::TLS { +namespace Botan::TLS::Internal { -Client_Hello_13& Handshake_State_13::store(Client_Hello_13 client_hello, const bool) +Client_Hello_13& Handshake_State_13_Base::store(Client_Hello_13 client_hello, const bool) { m_client_hello = std::move(client_hello); return m_client_hello.value(); } -Server_Hello_13& Handshake_State_13::store(Server_Hello_13 server_hello, const bool) +Server_Hello_13& Handshake_State_13_Base::store(Server_Hello_13 server_hello, const bool) { m_server_hello = std::move(server_hello); return m_server_hello.value(); } -Encrypted_Extensions& Handshake_State_13::store(Encrypted_Extensions encrypted_extensions, const bool) +Encrypted_Extensions& Handshake_State_13_Base::store(Encrypted_Extensions encrypted_extensions, const bool) { m_encrypted_extensions = std::move(encrypted_extensions); return m_encrypted_extensions.value(); } -Certificate_13& Handshake_State_13::store(Certificate_13 certificate, const bool) +Certificate_13& Handshake_State_13_Base::store(Certificate_13 certificate, const bool) { m_server_certs = std::move(certificate); return m_server_certs.value(); } -Certificate_Verify_13& Handshake_State_13::store(Certificate_Verify_13 certificate_verify, const bool) +Certificate_Verify_13& Handshake_State_13_Base::store(Certificate_Verify_13 certificate_verify, const bool) { m_server_verify = std::move(certificate_verify); return m_server_verify.value(); } -Finished_13& Handshake_State_13::store(Finished_13 finished, const bool from_peer) +Finished_13& Handshake_State_13_Base::store(Finished_13 finished, const bool from_peer) { auto& target = ((m_side == CLIENT) == from_peer) ? m_server_finished diff --git a/src/lib/tls/tls13/tls_handshake_state_13.h b/src/lib/tls/tls13/tls_handshake_state_13.h index 9999a0296e8..976fbeb006a 100644 --- a/src/lib/tls/tls13/tls_handshake_state_13.h +++ b/src/lib/tls/tls13/tls_handshake_state_13.h @@ -20,20 +20,21 @@ namespace Botan::TLS { -class BOTAN_TEST_API Handshake_State_13 +namespace Internal { +class Handshake_State_13_Base { public: - Handshake_State_13(Connection_Side side) : m_side(side) {} - const Client_Hello_13& client_hello() const { return get(m_client_hello); } const Server_Hello_13& server_hello() const { return get(m_server_hello); } const Encrypted_Extensions& encrypted_extensions() const { return get(m_encrypted_extensions); } const Certificate_13& certificate() const { return get(m_server_certs); } const Certificate_Verify_13& certificate_verify() const { return get(m_server_verify); } - const Finished_13& client_finished() const { return get(m_server_finished); } - const Finished_13& server_finished() const { return get(m_client_finished); } + const Finished_13& client_finished() const { return get(m_client_finished); } + const Finished_13& server_finished() const { return get(m_server_finished); } protected: + Handshake_State_13_Base(Connection_Side whoami) : m_side(whoami) {} + Client_Hello_13& store(Client_Hello_13 client_hello, const bool from_peer); Server_Hello_13& store(Server_Hello_13 server_hello, const bool from_peer); Encrypted_Extensions& store(Encrypted_Extensions encrypted_extensions, const bool from_peer); @@ -50,7 +51,6 @@ class BOTAN_TEST_API Handshake_State_13 return opt.value(); } - private: Connection_Side m_side; std::optional m_client_hello; @@ -61,34 +61,54 @@ class BOTAN_TEST_API Handshake_State_13 std::optional m_server_finished; std::optional m_client_finished; }; +} -class BOTAN_TEST_API Client_Handshake_State_13 : public Handshake_State_13 +/** + * Place to store TLS handshake messages + * + * This class is used to keep all handshake messages that have been received from and sent to + * the peer as part of the TLS 1.3 handshake. Getters are provided for all message types. + * Specializations for the client and server side provide specific setters in the form of + * `sent` and `received` that only allow those types of handshake messages that are sensible + * for the respective connection side. + * + * The handshake state machine as described in RFC 8446 Appendix A is NOT validated here. + */ +template +class BOTAN_TEST_API Handshake_State_13 : public Internal::Handshake_State_13_Base { public: - Client_Handshake_State_13() : Handshake_State_13(Connection_Side::CLIENT) {} + Handshake_State_13() : Handshake_State_13_Base(whoami) {} - decltype(auto) sent(Client_Handshake_13_Message message) - { - return std::visit([&](auto msg) -> Handshake_Message_13_Ref + decltype(auto) sent(Outbound_Message_T message) { - return store(std::move(msg), false); - }, std::move(message)); - } + return std::visit([&](auto msg) -> Handshake_Message_13_Ref + { + return store(std::move(msg), false); + }, std::move(message)); + } - decltype(auto) received(Handshake_Message_13 message) - { - return std::visit([&](auto msg) -> Server_Handshake_13_Message_Ref + decltype(auto) received(Handshake_Message_13 message) { - if constexpr (std::is_constructible_v) + return std::visit([&](auto msg) -> as_wrapped_references_t { - return store(std::move(msg), true); - } + if constexpr(std::is_constructible_v) + { + return store(std::move(msg), true); + } - throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "client received an illegal handshake message"); - }, std::move(message)); - } + throw TLS_Exception(Alert::UNEXPECTED_MESSAGE, "received an illegal handshake message"); + }, std::move(message)); + } }; +using Client_Handshake_State_13 = Handshake_State_13; + +using Server_Handshake_State_13 = Handshake_State_13; } #endif diff --git a/src/tests/test_tls_handshake_state_13.cpp b/src/tests/test_tls_handshake_state_13.cpp index 922401ce5fd..3c4b2f63463 100644 --- a/src/tests/test_tls_handshake_state_13.cpp +++ b/src/tests/test_tls_handshake_state_13.cpp @@ -38,7 +38,7 @@ Test::Result CHECK(const char* name, FunT check_fun) } const auto client_hello_message = Botan::hex_decode( // from RFC 8448 - "01 00 00 c0 03 03 cb" + "03 03 cb" "34 ec b1 e7 81 63 ba 1c 38 c6 da cb 19 6a 6d ff a2 1a 8d 99 12" "ec 18 a2 ef 62 83 02 4d ec e7 00 00 06 13 01 13 03 13 02 01 00" "00 91 00 00 00 0b 00 09 00 00 06 73 65 72 76 65 72 ff 01 00 01" @@ -50,16 +50,51 @@ const auto client_hello_message = Botan::hex_decode( // from RFC 8448 "01 04 02 05 02 06 02 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01"); const auto server_hello_message = Botan::hex_decode( - "02 00 00 56 03 03 a6" + "03 03 a6" "af 06 a4 12 18 60 dc 5e 6e 60 24 9c d3 4c 95 93 0c 8a c5 cb 14" "34 da c1 55 77 2e d3 e2 69 28 00 13 01 00 00 2e 00 33 00 24 00" "1d 00 20 c9 82 88 76 11 20 95 fe 66 76 2b db f7 c6 72 e1 56 d6" "cc 25 3b 83 3d f1 dd 69 b1 b0 4e 75 1f 0f 00 2b 00 02 03 04"); +const auto server_finished_message = Botan::hex_decode( + "9b 9b 14 1d 90 63 37 fb" + "d2 cb dc e7 1d f4 de da 4a b4 2c 30" + "95 72 cb 7f ff ee 54 54 b7 8f 07 18"); + const auto client_finished_message = Botan::hex_decode( - "14 00 00 20 a8 ec 43 6d 67 76 34 ae 52 5a c1" + "a8 ec 43 6d 67 76 34 ae 52 5a c1" "fc eb e1 1a 03 9e c1 76 94 fa c6 e9 85 27 b6 42 f2 ed d5 ce 61"); +std::vector finished_message_handling() + { + return + { + CHECK("Client sends and receives Finished messages", [&](auto& result) + { + Client_Handshake_State_13 state; + + Finished_13 client_finished(client_finished_message); + + auto client_fin = state.sent(std::move(client_finished)); + result.require("client can send client finished", + std::holds_alternative>(client_fin)); + result.test_throws("not stored as server Finished", [&] + { + state.server_finished(); + }); + result.test_eq("correct client Finished stored", state.client_finished().serialize(), client_finished_message); + + Finished_13 server_finished(server_finished_message); + + auto server_fin = state.received(std::move(server_finished)); + result.require("client can receive server finished", + std::holds_alternative>(server_fin)); + result.test_eq("correct client Finished stored", state.client_finished().serialize(), client_finished_message); + result.test_eq("correct server Finished stored", state.server_finished().serialize(), server_finished_message); + }), + }; + } + std::vector handshake_message_filtering() { return @@ -68,18 +103,32 @@ std::vector handshake_message_filtering() { Client_Handshake_State_13 state; - Client_Hello_13 client_hello({client_hello_message.cbegin()+4, client_hello_message.cend()}); + Client_Hello_13 client_hello(client_hello_message); auto filtered = state.sent(std::move(client_hello)); result.confirm("client can send client hello", std::holds_alternative>(filtered)); + result.test_eq("correct client hello stored", state.client_hello().serialize(), client_hello_message); + result.template test_throws("client cannot receive client hello", - "client received an illegal handshake message", [&] + "received an illegal handshake message", [&] { state.received(std::move(client_hello)); }); }), + CHECK("Client with server hello", [&](auto& result) + { + Client_Handshake_State_13 state; + + Server_Hello_13 server_hello(server_hello_message); + + auto filtered = state.received(std::move(server_hello)); + result.confirm("client can receive server hello", + std::holds_alternative>(filtered)); + + result.test_eq("correct server hello stored", state.server_hello().serialize(), server_hello_message); + }), }; } @@ -87,6 +136,7 @@ std::vector handshake_message_filtering() namespace Botan_Tests { BOTAN_REGISTER_TEST_FN("tls", "tls_handshake_state_13", + finished_message_handling, handshake_message_filtering); }