Skip to content

Commit

Permalink
Implement Mbed TLS Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean-Der committed Mar 13, 2023
1 parent d16c250 commit 5f50cbf
Show file tree
Hide file tree
Showing 9 changed files with 700 additions and 9 deletions.
17 changes: 15 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ project(libdatachannel
set(PROJECT_DESCRIPTION "C/C++ WebRTC network library featuring Data Channels, Media Transport, and WebSockets")

# Options
option(USE_MBEDTLS "Use Mbed TLS instead of OpenSSL" OFF)
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
option(USE_NICE "Use libnice instead of libjuice" OFF)
option(USE_SYSTEM_SRTP "Use system libSRTP" OFF)
Expand All @@ -17,11 +18,16 @@ option(WARNINGS_AS_ERRORS "Treat warnings as errors" OFF)
option(CAPI_STDCALL "Set calling convention of C API callbacks stdcall" OFF)
option(SCTP_DEBUG "Enable SCTP debugging output to verbose log" OFF)

if (USE_MBEDTLS AND USE_GNUTLS)
message(FATAL_ERROR "Both USE_MBEDTLS and USE_GNUTLS can not be enabled at the same time")
endif()


if(USE_GNUTLS)
option(USE_NETTLE "Use Nettle in libjuice" ON)
else()
option(USE_NETTLE "Use Nettle in libjuice" OFF)
if(NOT USE_SYSTEM_SRTP)
if(NOT USE_SYSTEM_SRTP AND NOT USE_MBEDTLS)
option(ENABLE_OPENSSL "Enable OpenSSL crypto engine for SRTP" ON)
endif()
endif()
Expand Down Expand Up @@ -317,11 +323,18 @@ if (USE_GNUTLS)
target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS)
target_link_libraries(datachannel-static PRIVATE GnuTLS::GnuTLS)
if (NOT NO_WEBSOCKET)
# Needed for SHA1, it should be present as GnuTLS cryptography backend
# Needed for SHA1, it should be present as GnuTLS/MbedTLS cryptography backend
find_package(Nettle REQUIRED)
target_link_libraries(datachannel PRIVATE Nettle::Nettle)
target_link_libraries(datachannel-static PRIVATE Nettle::Nettle)
endif()
elseif(USE_MBEDTLS)
find_package(MbedTLS REQUIRED)

target_compile_definitions(datachannel PRIVATE USE_MBEDTLS)
target_compile_definitions(datachannel-static PRIVATE USE_MBEDTLS)
target_link_libraries(datachannel PRIVATE MbedTLS::mbedtls)
target_link_libraries(datachannel-static PRIVATE MbedTLS::mbedtls)
else()
if(APPLE)
# This is a bug in CMake that causes it to prefer the system version over
Expand Down
235 changes: 233 additions & 2 deletions src/impl/certificate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

namespace rtc::impl {

string Certificate::fingerprint() const { return mFingerprint; }

#if USE_GNUTLS

Certificate Certificate::FromString(string crt_pem, string key_pem) {
Expand Down Expand Up @@ -111,6 +113,236 @@ Certificate::Certificate(shared_ptr<gnutls_certificate_credentials_t> creds)

gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }


string make_fingerprint(gnutls_certificate_credentials_t credentials) {
auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * {
gnutls_x509_crt_t *crt_list = nullptr;
unsigned int crt_list_size = 0;
gnutls::check(gnutls_certificate_get_x509_crt(credentials, 0, &crt_list, &crt_list_size));
assert(crt_list_size == 1);
return crt_list;
};

auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
gnutls_x509_crt_deinit(crt_list[0]);
gnutls_free(crt_list);
};

unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(), free_crt_list);

return make_fingerprint(*crt_list);
}

string make_fingerprint(gnutls_x509_crt_t crt) {
const size_t size = 32;
unsigned char buffer[size];
size_t len = size;
gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
"X509 fingerprint error");

std::ostringstream oss;
oss << std::hex << std::uppercase << std::setfill('0');
for (size_t i = 0; i < len; ++i) {
if (i)
oss << std::setw(1) << ':';
oss << std::setw(2) << unsigned(buffer[i]);
}
return oss.str();
}

#elif USE_MBEDTLS
string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt) {
uint8_t buffer[MBEDTLS_MD_MAX_SIZE];
std::stringstream fingerprint;

mbedtls::check(mbedtls_sha256(crt->raw.p, crt->raw.len, (unsigned char *) buffer, 0));

auto size = mbedtls_md_get_size(mbedtls_md_info_from_type(MBEDTLS_MD_SHA256));
for (auto i = 0; i < size; i++) {
fingerprint << std::setfill('0') << std::setw(2) << std::hex << static_cast<int>(buffer[i]);
if (i != (size - 1)) {
fingerprint << ":";
}
}

return fingerprint.str();
}

Certificate::Certificate(shared_ptr<mbedtls_x509_crt> crt, shared_ptr<mbedtls_pk_context> pk)
: mCrt(crt),
mPk(pk),
mFingerprint(make_fingerprint(crt)) {
}

Certificate Certificate::FromString(string crt_pem, string key_pem) {
PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)";
shared_ptr<mbedtls_x509_crt> crt(
new mbedtls_x509_crt,
[](mbedtls_x509_crt *p) {
mbedtls_x509_crt_free(p);
delete p;
});

shared_ptr<mbedtls_pk_context> pk(
new mbedtls_pk_context,
[](mbedtls_pk_context *p) {
mbedtls_pk_free(p);
delete p;
});

mbedtls_x509_crt_init(crt.get());
mbedtls::check(mbedtls_x509_crt_parse(crt.get(), (const unsigned char *) crt_pem.c_str(), crt_pem.length()));

mbedtls_pk_init(pk.get());
mbedtls::check(mbedtls_pk_parse_key(pk.get(), (const unsigned char *) key_pem.c_str(), key_pem.size(), NULL, 0, NULL, 0));

return Certificate(std::move(crt), std::move(pk));
}

Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
const string &pass) {
PLOG_DEBUG << "Importing certificate from PEM file (MbedTLS): " << crt_pem_file;

shared_ptr<mbedtls_x509_crt> crt(
new mbedtls_x509_crt,
[](mbedtls_x509_crt *p) {
mbedtls_x509_crt_free(p);
delete p;
});

shared_ptr<mbedtls_pk_context> pk(
new mbedtls_pk_context,
[](mbedtls_pk_context *p) {
mbedtls_pk_free(p);
delete p;
});

mbedtls_x509_crt_init(crt.get());
mbedtls::check(mbedtls_x509_crt_parse_file(crt.get(), crt_pem_file.c_str()));

mbedtls_pk_init(pk.get());
mbedtls::check(mbedtls_pk_parse_keyfile(pk.get(), key_pem_file.c_str(), pass.c_str(), 0, NULL));

return Certificate(std::move(crt), std::move(pk));
}

Certificate Certificate::Generate(CertificateType type, const string &commonName) {
PLOG_DEBUG << "Generating certificate (MbedTLS)";

shared_ptr<mbedtls_x509_crt> crt(
new mbedtls_x509_crt,
[](mbedtls_x509_crt *p) {
mbedtls_x509_crt_free(p);
delete p;
});
mbedtls_x509_crt_init(crt.get());

shared_ptr<mbedtls_pk_context> pk(
new mbedtls_pk_context,
[](mbedtls_pk_context *p) {
mbedtls_pk_free(p);
delete p;
});
mbedtls_pk_init(pk.get());

mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context drbg;
mbedtls_x509write_cert wcrt;
mbedtls_mpi serial;

mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&drbg);
mbedtls_x509write_crt_init(&wcrt);
mbedtls_mpi_init(&serial);

try {
mbedtls::check(mbedtls_ctr_drbg_seed(
&drbg, mbedtls_entropy_func, &entropy,
reinterpret_cast<const unsigned char *>(commonName.data()), commonName.size()));

switch (type) {
// RFC 8827 WebRTC Security Architecture 6.5. Communications Security
// All implementations MUST support DTLS 1.2 with the
// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 cipher suite and the P-256 curve
// See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5
//case CertificateType::Default:
case CertificateType::Ecdsa: {
mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)));
mbedtls::check(mbedtls_ecp_gen_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(*pk.get()), mbedtls_ctr_drbg_random, &drbg), "Unable to generate ECDSA P-256 key pair");
break;
}
case CertificateType::Default:
case CertificateType::Rsa: {
const unsigned int nbits = 2048;
const int exponent = 65537;

mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)));
mbedtls::check(mbedtls_rsa_gen_key(mbedtls_pk_rsa(*pk.get()), mbedtls_ctr_drbg_random, &drbg, nbits, exponent), "Unable to generate RSA key pair");
break;
}
default:
throw std::invalid_argument("Unknown certificate type");
}

auto now = std::chrono::time_point_cast<std::chrono::seconds>(std::chrono::system_clock::now());
string notBefore = mbedtls::format_time(now - std::chrono::hours(1));
string notAfter = mbedtls::format_time(now + std::chrono::hours(24 * 365));

const size_t serialBufferSize = 16;
unsigned char serialBuffer[serialBufferSize];
mbedtls::check(mbedtls_ctr_drbg_random(&drbg, serialBuffer, serialBufferSize));
mbedtls::check(mbedtls_mpi_read_binary(&serial, serialBuffer, serialBufferSize));

std::string name = std::string("O=" +commonName+ ",CN=" + commonName);
mbedtls::check(mbedtls_x509write_crt_set_serial(&wcrt, &serial));
mbedtls::check(mbedtls_x509write_crt_set_subject_name(&wcrt, name.c_str()));
mbedtls::check(mbedtls_x509write_crt_set_issuer_name(&wcrt, name.c_str()));
mbedtls::check(mbedtls_x509write_crt_set_validity(&wcrt, notBefore.c_str(), notAfter.c_str()));

mbedtls_x509write_crt_set_version(&wcrt, MBEDTLS_X509_CRT_VERSION_3);
mbedtls_x509write_crt_set_subject_key(&wcrt, pk.get());
mbedtls_x509write_crt_set_issuer_key(&wcrt, pk.get());
mbedtls_x509write_crt_set_md_alg(&wcrt, MBEDTLS_MD_SHA256);

const size_t certificateBufferSize = 4096;
unsigned char certificateBuffer[certificateBufferSize];
std::memset(certificateBuffer, 0, certificateBufferSize);

auto certificateLen = mbedtls_x509write_crt_der(&wcrt, certificateBuffer, certificateBufferSize, mbedtls_ctr_drbg_random, &drbg);
if (certificateLen <= 0){
throw std::runtime_error("Certificate generation failed");
}

mbedtls::check(mbedtls_x509_crt_parse_der(crt.get(), (certificateBuffer + certificateBufferSize - certificateLen), certificateLen));
} catch (...) {
mbedtls_entropy_free(&entropy);
mbedtls_ctr_drbg_free(&drbg);
throw;
}

return Certificate(std::move(crt), std::move(pk));
}

std::tuple<shared_ptr<mbedtls_x509_crt>, shared_ptr<mbedtls_pk_context>> Certificate::credentials() const {
return std::tuple<shared_ptr<mbedtls_x509_crt>, shared_ptr<mbedtls_pk_context>>{mCrt, mPk };
}

#elif USE_GNUTLS

// TODO
Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
: mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
mFingerprint(make_fingerprint(crt)) {

gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
"Unable to set certificate and key pair in credentials");
}

Certificate::Certificate(shared_ptr<gnutls_certificate_credentials_t> creds)
: mCredentials(std::move(creds)), mFingerprint(make_fingerprint(*mCredentials)) {}

gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }

string Certificate::fingerprint() const { return mFingerprint; }

string make_fingerprint(gnutls_certificate_credentials_t credentials) {
Expand Down Expand Up @@ -149,7 +381,7 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
return oss.str();
}

#else // USE_GNUTLS==0
#else // OPENSSL

namespace {

Expand Down Expand Up @@ -291,7 +523,6 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName
Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
: mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {}

string Certificate::fingerprint() const { return mFingerprint; }

std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
return {mX509.get(), mPKey.get()};
Expand Down
10 changes: 9 additions & 1 deletion src/impl/certificate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ class Certificate {
#if USE_GNUTLS
Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
gnutls_certificate_credentials_t credentials() const;
#else
#elif USE_MBEDTLS
Certificate(shared_ptr<mbedtls_x509_crt> crt, shared_ptr<mbedtls_pk_context> pk);
std::tuple<shared_ptr<mbedtls_x509_crt>, shared_ptr<mbedtls_pk_context>> credentials() const;
#else // OPENSSL
Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey);
std::tuple<X509 *, EVP_PKEY *> credentials() const;
#endif
Expand All @@ -42,6 +45,9 @@ class Certificate {
#if USE_GNUTLS
Certificate(shared_ptr<gnutls_certificate_credentials_t> creds);
const shared_ptr<gnutls_certificate_credentials_t> mCredentials;
#elif USE_MBEDTLS
const shared_ptr<mbedtls_x509_crt> mCrt;
const shared_ptr<mbedtls_pk_context> mPk;
#else
const shared_ptr<X509> mX509;
const shared_ptr<EVP_PKEY> mPKey;
Expand All @@ -53,6 +59,8 @@ class Certificate {
#if USE_GNUTLS
string make_fingerprint(gnutls_certificate_credentials_t credentials);
string make_fingerprint(gnutls_x509_crt_t crt);
#elif USE_MBEDTLS
string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt);
#else
string make_fingerprint(X509 *x509);
#endif
Expand Down
36 changes: 36 additions & 0 deletions src/impl/dtlssrtptransport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,42 @@ void DtlsSrtpTransport::postHandshake() {

serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
#elif USE_MBEDTLS
PLOG_INFO << "Deriving SRTP keying material (Mbed TLS)";
unsigned int keySize = SRTP_AES_128_KEY_LEN;
unsigned int saltSize = SRTP_SALT_LEN;
auto srtpProfile = srtp_profile_aes128_cm_sha1_80;
auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
mbedtls_dtls_srtp_info srtpInfo;

mbedtls_ssl_get_dtls_srtp_negotiation_result(&mSsl, &srtpInfo);
switch (srtpInfo.private_chosen_dtls_srtp_profile) {
case MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80:
break;
case MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_32:
srtpProfile = srtp_profile_aes128_cm_sha1_32;
break;
default:
throw std::runtime_error("Failed to get SRTP profile");
}


const size_t materialLen = keySizeWithSalt * 2;
std::vector<unsigned char> material(materialLen);
// The extractor provides the client write master key, the server write master key, the client
// write master salt and the server write master salt in that order.
const string label = "EXTRACTOR-dtls_srtp";

if (mbedtls_ssl_tls_prf(tlsProfile, (const unsigned char*) masterSecret, sizeof(masterSecret), label.c_str(),
(const unsigned char*) randBytes, sizeof(randBytes), material.data(), materialLen) != 0) {
throw std::runtime_error("Failed to derive SRTP keys");
}

// Order is client key, server key, client salt, and server salt
clientKey = material.data();
serverKey = clientKey + keySize;
clientSalt = serverKey + keySize;
serverSalt = clientSalt + saltSize;
#else
PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
auto profile = SSL_get_selected_srtp_profile(mSsl);
Expand Down
Loading

0 comments on commit 5f50cbf

Please sign in to comment.