Skip to content

Commit

Permalink
TLS 1.2 refactoring and cleanup
Browse files Browse the repository at this point in the history
Co-authored-by: René Meusel <rene.meusel@nexenio.com>
  • Loading branch information
Hannes Rantzsch and reneme committed Jan 14, 2022
1 parent 381bf63 commit 1b4c645
Show file tree
Hide file tree
Showing 32 changed files with 288 additions and 403 deletions.
1 change: 1 addition & 0 deletions src/lib/tls/info.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,6 @@ rng
rsa
sha2_32
sha2_64
tls12
x509
</requires>
10 changes: 4 additions & 6 deletions src/lib/tls/msg_cert_req.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,20 @@ Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_DN>& ca_certs) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V12>(io, hash, policy, ca_certs))
m_impl(Message_Factory::create<Certificate_Req_Impl>(protocol_version, io, hash, policy, ca_certs))
{
}

/**
* Deserialize a Certificate Request message
*/
Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(Message_Factory::create<Certificate_Req_Impl>(protocol_version, buf))
{
}

// Needed for std::unique_ptr<> m_impl member, as *_Impl type
// is available as a forward declaration in the header only.
Certificate_Req::~Certificate_Req() = default;

/**
Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_cert_req_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ namespace TLS {

Certificate_Req_Impl::Certificate_Req_Impl() = default;

Certificate_Req_Impl::~Certificate_Req_Impl() = default;

Handshake_Type Certificate_Req_Impl::type() const
{
return CERTIFICATE_REQUEST;
Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_cert_req_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class Certificate_Req_Impl : public Handshake_Message
virtual const std::vector<Signature_Scheme>& signature_schemes() const = 0;

explicit Certificate_Req_Impl();

virtual ~Certificate_Req_Impl() = 0;
};
}

Expand Down
10 changes: 4 additions & 6 deletions src/lib/tls/msg_cert_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,20 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* priv_key) :
m_impl( state.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V13>(io, state, policy, rng, priv_key)
: TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V12>(io, state, policy, rng, priv_key))
m_impl(Message_Factory::create<Certificate_Verify_Impl>(state.version(), io, state, policy, rng, priv_key))
{
}

/*
* Deserialize a Certificate Verify message
*/
Certificate_Verify::Certificate_Verify(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(Message_Factory::create<Certificate_Verify_Impl>(protocol_version, buf))
{
}

// Needed for std::unique_ptr<> m_impl member, as *_Impl type
// is available as a forward declaration in the header only.
Certificate_Verify::~Certificate_Verify() = default;

/*
Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_cert_verify_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ Certificate_Verify_Impl::Certificate_Verify_Impl(const std::vector<uint8_t>& buf
reader.assert_done();
}

Certificate_Verify_Impl::~Certificate_Verify_Impl() = default;

/*
* Serialize a Certificate Verify message
*/
Expand Down
12 changes: 5 additions & 7 deletions src/lib/tls/msg_cert_verify_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,14 @@ class Certificate_Verify_Impl : public Handshake_Message
const Handshake_State& state,
const Policy& policy) const;

explicit Certificate_Verify_Impl(Handshake_IO& io,
Handshake_State& state,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* key);
Certificate_Verify_Impl(Handshake_IO& io,
Handshake_State& state,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* key);

explicit Certificate_Verify_Impl(const std::vector<uint8_t>& buf);

virtual ~Certificate_Verify_Impl() = 0;

std::vector<uint8_t> serialize() const override;
private:
std::vector<uint8_t> m_signature;
Expand Down
10 changes: 4 additions & 6 deletions src/lib/tls/msg_certificate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ Certificate::Certificate(const Protocol_Version& protocol_version,
Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& cert_list) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V12>(io, hash, cert_list))
m_impl(Message_Factory::create<Certificate_Impl>(protocol_version, io, hash, cert_list))
{
}

Expand All @@ -56,12 +54,12 @@ Certificate::Certificate(const Protocol_Version& protocol_version,
*/
Certificate::Certificate(const Protocol_Version& protocol_version,
const std::vector<uint8_t>& buf, const Policy& policy) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V12>(buf, policy))
m_impl(Message_Factory::create<Certificate_Impl>(protocol_version, buf, policy))
{
}

// Needed for std::unique_ptr<> m_impl member, as *_Impl type
// is available as a forward declaration in the header only.
Certificate::~Certificate() = default;

/**
Expand Down
4 changes: 0 additions & 4 deletions src/lib/tls/msg_certificate_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ Handshake_Type Certificate_Impl::type() const
return CERTIFICATE;
}

Certificate_Impl::Certificate_Impl() = default;

Certificate_Impl::~Certificate_Impl() = default;

}

}
4 changes: 1 addition & 3 deletions src/lib/tls/msg_certificate_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class Certificate_Impl : public Handshake_Message
virtual size_t count() const = 0;
virtual bool empty() const = 0;

explicit Certificate_Impl();

virtual ~Certificate_Impl() = 0;
explicit Certificate_Impl() = default;
};

}
Expand Down
20 changes: 10 additions & 10 deletions src/lib/tls/msg_client_hello.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <botan/internal/stl_util.h>
#include <botan/internal/msg_client_hello_impl.h>
#include <botan/internal/msg_client_hello_impl_12.h>
#include <botan/internal/msg_client_hello_impl_13.h>
#include <botan/internal/tls_message_factory.h>

namespace Botan {
Expand Down Expand Up @@ -65,9 +64,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
const std::vector<std::string>& next_protocols) :
m_impl(client_settings.protocol_version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols))
m_impl(Message_Factory::create<Client_Hello_Impl>(client_settings.protocol_version(), io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols))
{
}

Expand All @@ -82,9 +79,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
const std::vector<uint8_t>& reneg_info,
const Session& session,
const std::vector<std::string>& next_protocols) :
m_impl(session.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, session, next_protocols)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, session, next_protocols))
m_impl(Message_Factory::create<Client_Hello_Impl>(session.version(), io, hash, policy, cb, rng, reneg_info, session, next_protocols))
{
}

Expand All @@ -95,11 +90,16 @@ Client_Hello::Client_Hello(const std::vector<uint8_t>& buf)
{
auto supported_versions = Client_Hello_Impl(buf).supported_versions();

m_impl = value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(buf);
const auto protocol_version =
value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? Protocol_Version::TLS_V13
: Protocol_Version::TLS_V12;

m_impl = Message_Factory::create<Client_Hello_Impl>(protocol_version, buf);
}

// Needed for std::unique_ptr<> m_impl member, as *_Impl type
// is available as a forward declaration in the header only.
Client_Hello::~Client_Hello() = default;


Expand Down
40 changes: 16 additions & 24 deletions src/lib/tls/msg_client_hello_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,25 @@ enum {
};

std::vector<uint8_t> make_hello_random(RandomNumberGenerator& rng,
const Policy& policy)
{
std::vector<uint8_t> buf(32);
rng.randomize(buf.data(), buf.size());

auto sha256 = HashFunction::create_or_throw("SHA-256");
sha256->update(buf);
sha256->final(buf);

if(policy.include_time_in_hello_random())
const Policy& policy)
{
const uint32_t time32 = static_cast<uint32_t>(
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()));
std::vector<uint8_t> buf(32);
rng.randomize(buf.data(), buf.size());

store_be(time32, buf.data());
}
auto sha256 = HashFunction::create_or_throw("SHA-256");
sha256->update(buf);
sha256->final(buf);

return buf;
}
if(policy.include_time_in_hello_random())
{
const uint32_t time32 = static_cast<uint32_t>(
std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()));

Client_Hello_Impl::Client_Hello_Impl() = default;
store_be(time32, buf.data());
}

return buf;
}

/*
* Create a new Client Hello message
Expand Down Expand Up @@ -236,9 +234,6 @@ Client_Hello_Impl::Client_Hello_Impl(const std::vector<uint8_t>& buf)
}
}


Client_Hello_Impl::~Client_Hello_Impl() = default;

Handshake_Type Client_Hello_Impl::type() const
{
return CLIENT_HELLO;
Expand Down Expand Up @@ -330,10 +325,7 @@ std::vector<uint8_t> Client_Hello_Impl::cookie_input_data() const
*/
bool Client_Hello_Impl::offered_suite(uint16_t ciphersuite) const
{
for(size_t i = 0; i != m_suites.size(); ++i)
if(m_suites[i] == ciphersuite)
return true;
return false;
return std::find(m_suites.cbegin(), m_suites.cend(), ciphersuite) != m_suites.cend();
}

std::vector<Signature_Scheme> Client_Hello_Impl::signature_schemes() const
Expand Down
40 changes: 19 additions & 21 deletions src/lib/tls/msg_client_hello_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,28 @@ class Policy;
class Client_Hello_Impl : public Handshake_Message
{
public:
explicit Client_Hello_Impl();

explicit Client_Hello_Impl(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
const std::vector<std::string>& next_protocols);

explicit Client_Hello_Impl(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& resumed_session,
const std::vector<std::string>& next_protocols);
explicit Client_Hello_Impl() = default;

Client_Hello_Impl(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
const std::vector<std::string>& next_protocols);

Client_Hello_Impl(Handshake_IO& io,
Handshake_Hash& hash,
const Policy& policy,
Callbacks& cb,
RandomNumberGenerator& rng,
const std::vector<uint8_t>& reneg_info,
const Session& resumed_session,
const std::vector<std::string>& next_protocols);

explicit Client_Hello_Impl(const std::vector<uint8_t>& buf);

virtual ~Client_Hello_Impl();

Handshake_Type type() const override;

Protocol_Version version() const;
Expand Down
10 changes: 4 additions & 6 deletions src/lib/tls/msg_finished.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ namespace TLS {
Finished::Finished(Handshake_IO& io,
Handshake_State& state,
Connection_Side side) :
m_impl( state.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V13>(io, state, side)
: TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V12>(io, state, side))
m_impl(Message_Factory::create<Finished_Impl>(state.version(), io, state, side))
{
}

Expand All @@ -43,12 +41,12 @@ std::vector<uint8_t> Finished::serialize() const
* Deserialize a Finished message
*/
Finished::Finished(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf):
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(Message_Factory::create<Finished_Impl>(protocol_version, buf))
{
}

// Needed for std::unique_ptr<> m_impl member, as *_Impl type
// is available as a forward declaration in the header only.
Finished::~Finished() = default;


Expand Down
14 changes: 8 additions & 6 deletions src/lib/tls/msg_finished_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ std::vector<uint8_t> finished_compute_verify(const Handshake_State& state,

std::vector<uint8_t> input;
std::vector<uint8_t> label;
if(side == CLIENT)
label += std::make_pair(TLS_CLIENT_LABEL, sizeof(TLS_CLIENT_LABEL));
else
label += std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL));
label += (side == CLIENT)
? std::make_pair(TLS_CLIENT_LABEL, sizeof(TLS_CLIENT_LABEL))
: std::make_pair(TLS_SERVER_LABEL, sizeof(TLS_SERVER_LABEL));

input += state.hash().final(state.ciphersuite().prf_algo());

Expand All @@ -60,6 +59,11 @@ Finished_Impl::Finished_Impl(Handshake_IO& io,
state.hash().update(io.send(*this));
}

Handshake_Type Finished_Impl::type() const
{
return FINISHED;
}

/*
* Serialize a Finished message
*/
Expand All @@ -74,8 +78,6 @@ std::vector<uint8_t> Finished_Impl::serialize() const
Finished_Impl::Finished_Impl(const std::vector<uint8_t>& buf) : m_verification_data(buf)
{}

Finished_Impl::~Finished_Impl() = default;

std::vector<uint8_t> Finished_Impl::verify_data() const
{
return m_verification_data;
Expand Down
Loading

0 comments on commit 1b4c645

Please sign in to comment.