From 289204bf39f3b079c2b1cbbaacbe83939295e1bd Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Wed, 22 Feb 2023 12:06:57 +0100 Subject: [PATCH] Implement Mbed TLS Backend Co-authored-by: tytan652 <17492366+tytan652@users.noreply.github.com> Co-authored-by: Paul-Louis Ageneau --- .github/workflows/build-mbedtls.yml | 37 ++++ CMakeLists.txt | 20 +- cmake/Modules/FindMbedTLS.cmake | 214 ++++++++++++++++++++ src/impl/certificate.cpp | 175 +++++++++++++++- src/impl/certificate.hpp | 10 +- src/impl/dtlssrtptransport.cpp | 38 +++- src/impl/dtlstransport.cpp | 299 +++++++++++++++++++++++++++- src/impl/dtlstransport.hpp | 28 ++- src/impl/init.cpp | 2 + src/impl/sha.cpp | 15 +- src/impl/tls.cpp | 76 ++++++- src/impl/tls.hpp | 27 ++- src/impl/tlstransport.cpp | 213 +++++++++++++++++++- src/impl/tlstransport.hpp | 15 ++ src/impl/verifiedtlstransport.cpp | 6 +- 15 files changed, 1154 insertions(+), 21 deletions(-) create mode 100644 .github/workflows/build-mbedtls.yml create mode 100644 cmake/Modules/FindMbedTLS.cmake diff --git a/.github/workflows/build-mbedtls.yml b/.github/workflows/build-mbedtls.yml new file mode 100644 index 000000000..ca203d394 --- /dev/null +++ b/.github/workflows/build-mbedtls.yml @@ -0,0 +1,37 @@ +name: Build with Mbed TLS +on: + push: + branches: + - master + pull_request: +jobs: + build-linux: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Homebrew + uses: Homebrew/actions/setup-homebrew@master + - name: Install Mbed TLS + run: brew update && brew install mbedtls + - name: submodules + run: git submodule update --init --recursive --depth 1 + - name: cmake + run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1 -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls) + - name: make + run: (cd build; make -j2) + - name: test + run: ./build/tests + build-macos: + runs-on: macos-latest + steps: + - uses: actions/checkout@v2 + - name: Install Mbed TLS + run: brew update && brew install mbedtls + - name: submodules + run: git submodule update --init --recursive --depth 1 + - name: cmake + run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1 -DENABLE_LOCAL_ADDRESS_TRANSLATION=1 -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls) + - name: make + run: (cd build; make -j2) + - name: test + run: ./build/tests diff --git a/CMakeLists.txt b/CMakeLists.txt index 2dc0ce181..8fff915b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ project(libdatachannel set(PROJECT_DESCRIPTION "C/C++ WebRTC network library featuring Data Channels, Media Transport, and WebSockets") # Options +option(USE_MBEDTLS "Use Mbed TLS instead of OpenSSL" OFF) option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF) option(USE_NICE "Use libnice instead of libjuice" OFF) option(PREFER_SYSTEM_LIB "Prefer system libraries over deps folder" OFF) @@ -21,12 +22,22 @@ option(WARNINGS_AS_ERRORS "Treat warnings as errors" OFF) option(CAPI_STDCALL "Set calling convention of C API callbacks stdcall" OFF) option(SCTP_DEBUG "Enable SCTP debugging output to verbose log" OFF) +if (USE_MBEDTLS AND USE_GNUTLS) + message(FATAL_ERROR "Both USE_MBEDTLS and USE_GNUTLS can not be enabled at the same time") +endif() + + if(USE_GNUTLS) option(USE_NETTLE "Use Nettle in libjuice" ON) else() option(USE_NETTLE "Use Nettle in libjuice" OFF) + if(NOT USE_SYSTEM_SRTP) - option(ENABLE_OPENSSL "Enable OpenSSL crypto engine for SRTP" ON) + if (USE_MBEDTLS) + option(ENABLE_MBEDTLS "Enable Mbed TLS crypto engine for SRTP" ON) + else() + option(ENABLE_OPENSSL "Enable OpenSSL crypto engine for SRTP" ON) + endif() endif() endif() @@ -337,6 +348,13 @@ if (USE_GNUTLS) target_link_libraries(datachannel PRIVATE Nettle::Nettle) target_link_libraries(datachannel-static PRIVATE Nettle::Nettle) endif() +elseif(USE_MBEDTLS) + find_package(MbedTLS 3 REQUIRED) + + target_compile_definitions(datachannel PRIVATE USE_MBEDTLS) + target_compile_definitions(datachannel-static PRIVATE USE_MBEDTLS) + target_link_libraries(datachannel PRIVATE MbedTLS::MbedTLS) + target_link_libraries(datachannel-static PRIVATE MbedTLS::MbedTLS) else() if(APPLE) # This is a bug in CMake that causes it to prefer the system version over diff --git a/cmake/Modules/FindMbedTLS.cmake b/cmake/Modules/FindMbedTLS.cmake new file mode 100644 index 000000000..f7b561887 --- /dev/null +++ b/cmake/Modules/FindMbedTLS.cmake @@ -0,0 +1,214 @@ +#[=======================================================================[.rst +FindMbedTLS +----------- + +FindModule for MbedTLS and associated libraries + +Components +^^^^^^^^^^ + +This module contains provides several components: + +``MbedCrypto`` +``MbedTLS`` +``MbedX509`` + +Import targets exist for each component. + +Imported Targets +^^^^^^^^^^^^^^^^ + +This module defines the :prop_tgt:`IMPORTED` targets: + +``MbedTLS::MbedCrypto`` + Crypto component + +``MbedTLS::MbedTLS`` + TLS component + +``MbedTLS::MbedX509`` + X509 component + +Result Variables +^^^^^^^^^^^^^^^^ + +This module sets the following variables: + +``MbedTLS_FOUND`` + True, if all required components and the core library were found. +``MbedTLS_VERSION`` + Detected version of found MbedTLS libraries. + +``MbedTLS__VERSION`` + Detected version of found MbedTLS component library. + +Cache variables +^^^^^^^^^^^^^^^ + +The following cache variables may also be set: + +``MbedTLS__LIBRARY`` + Path to the library component of MbedTLS. +``MbedTLS__INCLUDE_DIR`` + Directory containing ``.h``. + +Distributed under the MIT License, see accompanying LICENSE file or +https://github.com/PatTheMav/cmake-finders/blob/master/LICENSE for details. +(c) 2023 Patrick Heyer + +#]=======================================================================] + +# cmake-format: off +# cmake-lint: disable=C0103 +# cmake-lint: disable=C0301 +# cmake-lint: disable=C0307 +# cmake-format: on + +include(FindPackageHandleStandardArgs) + +find_package(PkgConfig QUIET) +if(PKG_CONFIG_FOUND) + pkg_check_modules(PC_MbedTLS QUIET mbedtls mbedcrypto mbedx509) +endif() + +# MbedTLS_set_soname: Set SONAME on imported library targets +macro(MbedTLS_set_soname component) + if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin") + execute_process( + COMMAND sh -c "otool -D '${Mbed${component}_LIBRARY}' | grep -v '${Mbed${component}_LIBRARY}'" + OUTPUT_VARIABLE _output + RESULT_VARIABLE _result) + + if(_result EQUAL 0 AND _output MATCHES "^@rpath/") + set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_SONAME "${_output}") + endif() + elseif(CMAKE_HOST_SYSTEM_NAME MATCHES "Linux|FreeBSD") + execute_process( + COMMAND sh -c "objdump -p '${Mbed${component}_LIBRARY}' | grep SONAME" + OUTPUT_VARIABLE _output + RESULT_VARIABLE _result) + + if(_result EQUAL 0) + string(REGEX REPLACE "[ \t]+SONAME[ \t]+([^ \t]+)" "\\1" _soname "${_output}") + set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_SONAME "${_soname}") + unset(_soname) + endif() + endif() + unset(_output) + unset(_result) +endmacro() + +find_path( + MbedTLS_INCLUDE_DIR + NAMES mbedtls/ssl.h + HINTS "${PC_MbedTLS_INCLUDE_DIRS}" + PATHS /usr/include /usr/local/include + DOC "MbedTLS include directory") + +if(PC_MbedTLS_VERSION VERSION_GREATER 0) + set(MbedTLS_VERSION ${PC_MbedTLS_VERSION}) +elseif(EXISTS "${MbedTLS_INCLUDE_DIR}/mbedtls/build_info.h") + file(STRINGS "${MbedTLS_INCLUDE_DIR}/mbedtls/build_info.h" _VERSION_STRING + REGEX "#define[ \t]+MBEDTLS_VERSION_STRING[ \t]+.+") + string(REGEX REPLACE ".*#define[ \t]+MBEDTLS_VERSION_STRING[ \t]+\"(.+)\".*" "\\1" MbedTLS_VERSION + "${_VERSION_STRING}") +else() + if(NOT MbedTLS_FIND_QUIETLY) + message(AUTHOR_WARNING "Failed to find MbedTLS version.") + endif() + set(MbedTLS_VERSION 0.0.0) +endif() + +find_library( + MbedTLS_LIBRARY + NAMES libmbedtls mbedtls + HINTS "${PC_MbedTLS_LIBRARY_DIRS}" + PATHS /usr/lib /usr/local/lib + DOC "MbedTLS location") + +find_library( + MbedCrypto_LIBRARY + NAMES libmbedcrypto mbedcrypto + HINTS "${PC_MbedTLS_LIBRARY_DIRS}" + PATHS /usr/lib /usr/local/lib + DOC "MbedCrypto location") + +find_library( + MbedX509_LIBRARY + NAMES libmbedx509 mbedx509 + HINTS "${PC_MbedTLS_LIBRARY_DIRS}" + PATHS /usr/lib /usr/local/lib + DOC "MbedX509 location") + +if(MbedTLS_LIBRARY + AND NOT MbedCrypto_LIBRARY + AND NOT MbedX509_LIBRARY) + set(CMAKE_REQUIRED_LIBRARIES "${MbedTLS_LIBRARY}") + set(CMAKE_REQUIRED_INCLUDES "${MbedTLS_INCLUDE_DIR}") + + check_symbol_exists(mbedtls_x509_crt_init "mbedtls/x590_crt.h" MbedTLS_INCLUDES_X509) + check_symbol_exists(mbedtls_sha256_init "mbedtls/sha256.h" MbedTLS_INCLUDES_CRYPTO) + unset(CMAKE_REQUIRED_LIBRARIES) + unset(CMAKE_REQUIRED_INCLUDES) +endif() + +if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin|Windows") + set(MbedTLS_ERROR_REASON "Ensure that an MbedTLS distribution is provided as part of CMAKE_PREFIX_PATH.") +elseif(CMAKE_HOST_SYSTEM_NAME MATCHES "Linux|FreeBSD") + set(MbedTLS_ERROR_REASON "Ensure that MbedTLS is installed on the system.") +endif() + +if(MbedTLS_INCLUDES_X509 AND MbedTLS_INCLUDES_CRYPTO) + find_package_handle_standard_args( + MbedTLS + REQUIRED_VARS MbedTLS_LIBRARY MbedTLS_INCLUDE_DIR + VERSION_VAR MbedTLS_VERSION REASON_FAILURE_MESSAGE "${MbedTLS_ERROR_REASON}") + mark_as_advanced(MbedTLS_LIBRARY MbedTLS_INCLUDE_DIR) + list(APPEND _COMPONENTS TLS) +else() + find_package_handle_standard_args( + MbedTLS + REQUIRED_VARS MbedTLS_LIBRARY MbedCrypto_LIBRARY MbedX509_LIBRARY MbedTLS_INCLUDE_DIR + VERSION_VAR MbedTLS_VERSION REASON_FAILURE_MESSAGE "${MbedTLS_ERROR_REASON}") + mark_as_advanced(MbedTLS_LIBRARY MbedCrypto_LIBRARY MbedX509_LIBRARY MbedTLS_INCLUDE_DIR) + list(APPEND _COMPONENTS TLS Crypto X509) +endif() +unset(MbedTLS_ERROR_REASON) + +if(MbedTLS_FOUND) + foreach(component IN LISTS _COMPONENTS) + if(NOT TARGET MbedTLS::Mbed${component}) + if(IS_ABSOLUTE "${Mbed${component}_LIBRARY}") + add_library(MbedTLS::Mbed${component} UNKNOWN IMPORTED) + set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_LOCATION "${Mbed${component}_LIBRARY}") + else() + add_library(MbedTLS::Mbed${component} INTERFACE IMPORTED) + set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_LIBNAME "${Mbed${component}_LIBRARY}") + endif() + + mbedtls_set_soname(${component}) + set_target_properties( + MbedTLS::MbedTLS + PROPERTIES INTERFACE_COMPILE_OPTIONS "${PC_MbedTLS_CFLAGS_OTHER}" + INTERFACE_INCLUDE_DIRECTORIES "${MbedTLS_INCLUDE_DIR}" + VERSION ${MbedTLS_VERSION}) + endif() + endforeach() + + if(MbedTLS_INCLUDES_X509 AND MbedTLS_INCLUDES_CRYPTO) + set(MbedTLS_LIBRARIES ${MbedTLS_LIBRARY}) + set(MBEDTLS_INCLUDE_DIRS ${MbedTLS_INCLUDE_DIR}) + else() + set(MbedTLS_LIBRARIES ${MbedTLS_LIBRARY} ${MbedCrypto_LIBRARY} ${MbedX509_LIBRARY}) + set_property(TARGET MbedTLS::MbedTLS PROPERTY INTERFACE_LINK_LIBRARIES MbedTLS::MbedCrypto MbedTLS::MbedX509) + set(MBEDTLS_INCLUDE_DIRS ${MbedTLS_INCLUDE_DIR}) + endif() +endif() + +include(FeatureSummary) +set_package_properties( + MbedTLS PROPERTIES + URL "https://www.trustedfirmware.org/projects/mbed-tls" + DESCRIPTION + "A C library implementing cryptographic primitives, X.509 certificate manipulation, and the SSL/TLS and DTLS protocols." +) diff --git a/src/impl/certificate.cpp b/src/impl/certificate.cpp index 9b4fa26c3..6d773d837 100644 --- a/src/impl/certificate.cpp +++ b/src/impl/certificate.cpp @@ -111,8 +111,6 @@ Certificate::Certificate(shared_ptr creds) gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } -string Certificate::fingerprint() const { return mFingerprint; } - string make_fingerprint(gnutls_certificate_credentials_t credentials) { auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * { gnutls_x509_crt_t *crt_list = nullptr; @@ -149,7 +147,172 @@ string make_fingerprint(gnutls_x509_crt_t crt) { return oss.str(); } -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS +string make_fingerprint(shared_ptr crt) { + const int size = 32; + uint8_t buffer[size]; + std::stringstream fingerprint; + + mbedtls::check( + mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast(buffer), 0), + "Failed to generate certificate fingerprint"); + + for (auto i = 0; i < size; i++) { + fingerprint << std::setfill('0') << std::setw(2) << std::hex << static_cast(buffer[i]); + if (i != (size - 1)) { + fingerprint << ":"; + } + } + + return fingerprint.str(); +} + +Certificate::Certificate(shared_ptr crt, shared_ptr pk) + : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt)) {} + +Certificate Certificate::FromString(string crt_pem, string key_pem) { + PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)"; + + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls::check(mbedtls_x509_crt_parse(crt.get(), + reinterpret_cast(crt_pem.c_str()), + crt_pem.length()), + "Failed to parse certificate"); + mbedtls::check(mbedtls_pk_parse_key(pk.get(), + reinterpret_cast(key_pem.c_str()), + key_pem.size(), NULL, 0, NULL, 0), + "Failed to parse key"); + + return Certificate(std::move(crt), std::move(pk)); +} + +Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file, + const string &pass) { + PLOG_DEBUG << "Importing certificate from PEM file (MbedTLS): " << crt_pem_file; + + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls::check(mbedtls_x509_crt_parse_file(crt.get(), crt_pem_file.c_str()), + "Failed to parse certificate"); + mbedtls::check(mbedtls_pk_parse_keyfile(pk.get(), key_pem_file.c_str(), pass.c_str(), 0, NULL), + "Failed to parse key"); + + return Certificate(std::move(crt), std::move(pk)); +} + +Certificate Certificate::Generate(CertificateType type, const string &commonName) { + PLOG_DEBUG << "Generating certificate (MbedTLS)"; + + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context drbg; + mbedtls_x509write_cert wcrt; + mbedtls_mpi serial; + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&drbg); + mbedtls_ctr_drbg_set_prediction_resistance(&drbg, MBEDTLS_CTR_DRBG_PR_ON); + mbedtls_x509write_crt_init(&wcrt); + mbedtls_mpi_init(&serial); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed( + &drbg, mbedtls_entropy_func, &entropy, + reinterpret_cast(commonName.data()), commonName.size())); + + switch (type) { + // RFC 8827 WebRTC Security Architecture 6.5. Communications Security + // All implementations MUST support DTLS 1.2 with the + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 cipher suite and the P-256 curve + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + case CertificateType::Default: + case CertificateType::Ecdsa: { + mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))); + mbedtls::check(mbedtls_ecp_gen_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(*pk.get()), + mbedtls_ctr_drbg_random, &drbg), + "Unable to generate ECDSA P-256 key pair"); + break; + } + case CertificateType::Rsa: { + const unsigned int nbits = 2048; + const int exponent = 65537; + + mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))); + mbedtls::check(mbedtls_rsa_gen_key(mbedtls_pk_rsa(*pk.get()), mbedtls_ctr_drbg_random, + &drbg, nbits, exponent), + "Unable to generate RSA key pair"); + break; + } + default: + throw std::invalid_argument("Unknown certificate type"); + } + + auto now = std::chrono::system_clock::now(); + string notBefore = mbedtls::format_time(now - std::chrono::hours(1)); + string notAfter = mbedtls::format_time(now + std::chrono::hours(24 * 365)); + + const size_t serialBufferSize = 16; + unsigned char serialBuffer[serialBufferSize]; + mbedtls::check(mbedtls_ctr_drbg_random(&drbg, serialBuffer, serialBufferSize), + "Failed to generate certificate"); + mbedtls::check(mbedtls_mpi_read_binary(&serial, serialBuffer, serialBufferSize), + "Failed to generate certificate"); + + std::string name = std::string("O=" + commonName + ",CN=" + commonName); + mbedtls::check(mbedtls_x509write_crt_set_serial(&wcrt, &serial), + "Failed to generate certificate"); + mbedtls::check(mbedtls_x509write_crt_set_subject_name(&wcrt, name.c_str()), + "Failed to generate certificate"); + mbedtls::check(mbedtls_x509write_crt_set_issuer_name(&wcrt, name.c_str()), + "Failed to generate certificate"); + mbedtls::check( + mbedtls_x509write_crt_set_validity(&wcrt, notBefore.c_str(), notAfter.c_str()), + "Failed to generate certificate"); + + mbedtls_x509write_crt_set_version(&wcrt, MBEDTLS_X509_CRT_VERSION_3); + mbedtls_x509write_crt_set_subject_key(&wcrt, pk.get()); + mbedtls_x509write_crt_set_issuer_key(&wcrt, pk.get()); + mbedtls_x509write_crt_set_md_alg(&wcrt, MBEDTLS_MD_SHA256); + + const size_t certificateBufferSize = 4096; + unsigned char certificateBuffer[certificateBufferSize]; + std::memset(certificateBuffer, 0, certificateBufferSize); + + auto certificateLen = mbedtls_x509write_crt_der( + &wcrt, certificateBuffer, certificateBufferSize, mbedtls_ctr_drbg_random, &drbg); + if (certificateLen <= 0) { + throw std::runtime_error("Certificate generation failed"); + } + + mbedtls::check(mbedtls_x509_crt_parse_der( + crt.get(), (certificateBuffer + certificateBufferSize - certificateLen), + certificateLen), + "Failed to generate certificate"); + } catch (...) { + mbedtls_entropy_free(&entropy); + mbedtls_ctr_drbg_free(&drbg); + mbedtls_x509write_crt_free(&wcrt); + mbedtls_mpi_free(&serial); + throw; + } + + mbedtls_entropy_free(&entropy); + mbedtls_ctr_drbg_free(&drbg); + mbedtls_x509write_crt_free(&wcrt); + mbedtls_mpi_free(&serial); + return Certificate(std::move(crt), std::move(pk)); +} + +std::tuple, shared_ptr> +Certificate::credentials() const { + return {mCrt, mPk}; +} + +#else // OPENSSL namespace { @@ -291,8 +454,6 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName Certificate::Certificate(shared_ptr x509, shared_ptr pkey) : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {} -string Certificate::fingerprint() const { return mFingerprint; } - std::tuple Certificate::credentials() const { return {mX509.get(), mPKey.get()}; } @@ -316,7 +477,7 @@ string make_fingerprint(X509 *x509) { #endif -// Common for GnuTLS and OpenSSL +// Common for GnuTLS, Mbed TLS, and OpenSSL future_certificate_ptr make_certificate(CertificateType type) { return ThreadPool::Instance().enqueue([type, token = Init::Instance().token()]() { @@ -324,4 +485,6 @@ future_certificate_ptr make_certificate(CertificateType type) { }); } +string Certificate::fingerprint() const { return mFingerprint; } + } // namespace rtc::impl diff --git a/src/impl/certificate.hpp b/src/impl/certificate.hpp index 564acb571..363632652 100644 --- a/src/impl/certificate.hpp +++ b/src/impl/certificate.hpp @@ -29,7 +29,10 @@ class Certificate { #if USE_GNUTLS Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey); gnutls_certificate_credentials_t credentials() const; -#else +#elif USE_MBEDTLS + Certificate(shared_ptr crt, shared_ptr pk); + std::tuple, shared_ptr> credentials() const; +#else // OPENSSL Certificate(shared_ptr x509, shared_ptr pkey); std::tuple credentials() const; #endif @@ -42,6 +45,9 @@ class Certificate { #if USE_GNUTLS Certificate(shared_ptr creds); const shared_ptr mCredentials; +#elif USE_MBEDTLS + const shared_ptr mCrt; + const shared_ptr mPk; #else const shared_ptr mX509; const shared_ptr mPKey; @@ -53,6 +59,8 @@ class Certificate { #if USE_GNUTLS string make_fingerprint(gnutls_certificate_credentials_t credentials); string make_fingerprint(gnutls_x509_crt_t crt); +#elif USE_MBEDTLS +string make_fingerprint(shared_ptr crt); #else string make_fingerprint(X509 *x509); #endif diff --git a/src/impl/dtlssrtptransport.cpp b/src/impl/dtlssrtptransport.cpp index 00a77fa82..41d50f5d8 100644 --- a/src/impl/dtlssrtptransport.cpp +++ b/src/impl/dtlssrtptransport.cpp @@ -213,7 +213,7 @@ bool DtlsSrtpTransport::demuxMessage(message_ptr message) { } else { COUNTER_UNKNOWN_PACKET_TYPE++; PLOG_DEBUG << "Unknown packet type, value=" << unsigned(value1) - << ", size=" << message->size(); + << ", size=" << message->size(); return true; } } @@ -263,12 +263,46 @@ void DtlsSrtpTransport::postHandshake() { serverKey = reinterpret_cast(serverKeyDatum.data); serverSalt = reinterpret_cast(serverSaltDatum.data); +#elif USE_MBEDTLS + PLOG_INFO << "Deriving SRTP keying material (Mbed TLS)"; + unsigned int keySize = SRTP_AES_128_KEY_LEN; + unsigned int saltSize = SRTP_SALT_LEN; + auto srtpProfile = srtp_profile_aes128_cm_sha1_80; + auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT; + mbedtls_dtls_srtp_info srtpInfo; + + mbedtls_ssl_get_dtls_srtp_negotiation_result(&mSsl, &srtpInfo); + if (srtpInfo.private_chosen_dtls_srtp_profile != MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80) { + throw std::runtime_error("Failed to get SRTP profile"); + } + + const size_t materialLen = keySizeWithSalt * 2; + std::vector material(materialLen); + // The extractor provides the client write master key, the server write master key, the client + // write master salt and the server write master salt in that order. + const string label = "EXTRACTOR-dtls_srtp"; + + if (mTlsProfile == MBEDTLS_SSL_TLS_PRF_NONE) { + throw std::logic_error("Failed to get SRTP profile"); + } + + if (mbedtls_ssl_tls_prf(mTlsProfile, reinterpret_cast(mMasterSecret), 32, + label.c_str(), reinterpret_cast(mRandBytes), 32, + material.data(), materialLen) != 0) { + throw std::runtime_error("Failed to derive SRTP keys"); + } + + // Order is client key, server key, client salt, and server salt + clientKey = material.data(); + serverKey = clientKey + keySize; + clientSalt = serverKey + keySize; + serverSalt = clientSalt + saltSize; #else PLOG_INFO << "Deriving SRTP keying material (OpenSSL)"; auto profile = SSL_get_selected_srtp_profile(mSsl); if (!profile) throw std::runtime_error("Failed to get SRTP profile: " + - openssl::error_string(ERR_get_error())); + openssl::error_string(ERR_get_error())); PLOG_DEBUG << "srtp profile used is: " << profile->name; auto [keySize, saltSize, srtpProfile] = getEncryptionParams(profile->name); auto keySizeWithSalt = keySize + saltSize; diff --git a/src/impl/dtlstransport.cpp b/src/impl/dtlstransport.cpp index 9b676ea57..55c50ece7 100644 --- a/src/impl/dtlstransport.cpp +++ b/src/impl/dtlstransport.cpp @@ -360,7 +360,300 @@ int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* m } } -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + +mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = { + MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80, + MBEDTLS_TLS_SRTP_UNSET, +}; + +DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, + optional mtu, verifier_callback verifierCallback, + state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), + mVerifierCallback(std::move(verifierCallback)), + mIsClient(lower->role() == Description::Role::Active) { + + PLOG_DEBUG << "Initializing DTLS transport (MbedTLS)"; + + if (!mCertificate) + throw std::invalid_argument("DTLS certificate is null"); + + mbedtls_entropy_init(&mEntropy); + mbedtls_ctr_drbg_init(&mDrbg); + mbedtls_ssl_init(&mSsl); + mbedtls_ssl_config_init(&mConf); + mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0), + "Failed creating Mbed TLS Context"); + + mbedtls::check(mbedtls_ssl_config_defaults( + &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT), + "Failed creating Mbed TLS Context"); + + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg); + + auto [crt, pk] = mCertificate->credentials(); + mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get()), + "Failed creating Mbed TLS Context"); + + mbedtls_ssl_conf_dtls_cookies(&mConf, NULL, NULL, NULL); + mbedtls_ssl_conf_dtls_srtp_protection_profiles(&mConf, srtpSupportedProtectionProfiles); + + mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf), "Failed creating Mbed TLS Context"); + + size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6 + mbedtls_ssl_set_mtu(&mSsl, static_cast(mtu)); + PLOG_VERBOSE << "DTLS MTU set to " << mtu; + + mbedtls_ssl_set_export_keys_cb(&mSsl, DtlsTransport::ExportKeysCallback, this); + mbedtls_ssl_set_bio(&mSsl, this, WriteCallback, ReadCallback, NULL); + mbedtls_ssl_set_timer_cb(&mSsl, this, SetTimerCallback, GetTimerCallback); + } catch (...) { + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); + throw; + } + + // Set recommended medium-priority DSCP value for handshake + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + mCurrentDscp = 10; // AF11: Assured Forwarding class 1, low drop probability +} + +DtlsTransport::~DtlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying DTLS transport"; + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); +} + +void DtlsTransport::Init() { + // Nothing to do +} + +void DtlsTransport::Cleanup() { + // Nothing to do +} + +void DtlsTransport::start() { + PLOG_DEBUG << "Starting DTLS transport"; + registerIncoming(); + changeState(State::Connecting); + + enqueueRecv(); // to initiate the handshake +} + +void DtlsTransport::stop() { + PLOG_DEBUG << "Stopping DTLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool DtlsTransport::send(message_ptr message) { + if (!message || state() != State::Connected) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + int ret; + do { + std::lock_guard lock(mMutex); + mCurrentDscp = message->dscp; + + if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl))) + return false; + + ret = mbedtls_ssl_write(&mSsl, reinterpret_cast(message->data()), + message->size()); + } while (ret == MBEDTLS_ERR_SSL_WANT_WRITE); + mbedtls::check(ret); + + return mOutgoingResult; +} + +void DtlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool DtlsTransport::outgoing(message_ptr message) { + message->dscp = mCurrentDscp; + + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +bool DtlsTransport::demuxMessage(message_ptr) { + // Dummy + return false; +} + +void DtlsTransport::postHandshake() { + // Dummy +} + +void DtlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handle handshake if connecting + if (state() == State::Connecting) { + auto ret = mbedtls_ssl_handshake(&mSsl); + + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || + ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) { + ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs), [weak_this = weak_from_this()]() { + if (auto locked = weak_this.lock()) + locked->doRecv(); + }); + return; + } + + mbedtls::check(ret, "DTLS handshake failed"); + PLOG_INFO << "DTLS handshake finished"; + changeState(State::Connected); + postHandshake(); + } + + if (state() == State::Connected) { + while (true) { + mMutex.lock(); + auto ret = + mbedtls_ssl_read(&mSsl, reinterpret_cast(buffer), bufferSize); + mMutex.unlock(); + + if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + // Closed + PLOG_DEBUG << "DTLS connection cleanly closed"; + break; + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || + ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) { + return; + } + mbedtls::check(ret); + + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "DTLS recv: " << e.what(); + } + + PLOG_INFO << "DTLS closed"; + changeState(State::Disconnected); + recv(nullptr); +} + +void DtlsTransport::ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type /*type*/, + const unsigned char *secret, size_t secret_len, + const unsigned char client_random[32], + const unsigned char server_random[32], + mbedtls_tls_prf_types tls_prf_type) { + auto dtlsTransport = static_cast(ctx); + std::memcpy(dtlsTransport->mMasterSecret, secret, secret_len); + std::memcpy(dtlsTransport->mRandBytes, client_random, 32); + std::memcpy(dtlsTransport->mRandBytes + 32, server_random, 32); + dtlsTransport->mTlsProfile = tls_prf_type; +} + +int DtlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + try { + if (len > 0) { + auto b = reinterpret_cast(buf); + t->outgoing(make_message(b, b + len)); + } + return int(len); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } +} + +int DtlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + try { + while (t->mIncomingQueue.running()) { + auto next = t->mIncomingQueue.pop(); + if (!next) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + + message_ptr message = std::move(*next); + if (t->demuxMessage(message)) + continue; + + auto bufMin = std::min(len, size_t(message->size())); + std::memcpy(buf, message->data(), bufMin); + return int(len); + } + + // Closed + return 0; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + ; + } +} + +void DtlsTransport::SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms) { + auto dtlsTransport = static_cast(ctx); + dtlsTransport->mIntMs = int_ms; + dtlsTransport->mFinMs = fin_ms; + + if (fin_ms != 0) { + dtlsTransport->mTimerSetAt = std::chrono::steady_clock::now(); + } +} + +int DtlsTransport::GetTimerCallback(void *ctx) { + auto dtlsTransport = static_cast(ctx); + auto now = std::chrono::steady_clock::now(); + + if (dtlsTransport->mFinMs == 0) { + return -1; + } else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mFinMs)) { + return 2; + } else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mIntMs)) { + return 1; + } else { + return 0; + } +} + +#else // OPENSSL BIO_METHOD *DtlsTransport::BioMethods = NULL; int DtlsTransport::TransportExIndex = -1; @@ -415,8 +708,8 @@ DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr cer SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION); SSL_CTX_set_read_ahead(mCtx, 1); - //sent the dtls close_notify alert - //SSL_CTX_set_quiet_shutdown(mCtx, 1); + // sent the dtls close_notify alert + // SSL_CTX_set_quiet_shutdown(mCtx, 1); SSL_CTX_set_info_callback(mCtx, InfoCallback); SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, diff --git a/src/impl/dtlstransport.hpp b/src/impl/dtlstransport.hpp index d8a185a51..673464fea 100644 --- a/src/impl/dtlstransport.hpp +++ b/src/impl/dtlstransport.hpp @@ -70,7 +70,33 @@ class DtlsTransport : public Transport, public std::enable_shared_from_this mTimerSetAt; + + char mMasterSecret[48]; + char mRandBytes[64]; + mbedtls_tls_prf_types mTlsProfile = MBEDTLS_SSL_TLS_PRF_NONE; + + static int WriteCallback(void *ctx, const unsigned char *buf, size_t len); + static int ReadCallback(void *ctx, unsigned char *buf, size_t len); + static void ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type type, + const unsigned char *secret, size_t secret_len, + const unsigned char client_random[32], + const unsigned char server_random[32], + mbedtls_tls_prf_types tls_prf_type); + static void SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms); + static int GetTimerCallback(void *ctx); + +#else // OPENSSL SSL_CTX *mCtx = NULL; SSL *mSsl = NULL; BIO *mInBio, *mOutBio; diff --git a/src/impl/init.cpp b/src/impl/init.cpp index 4527acc46..a926c572c 100644 --- a/src/impl/init.cpp +++ b/src/impl/init.cpp @@ -128,6 +128,8 @@ void Init::doInit() { #if USE_GNUTLS // Nothing to do +#elif USE_MBEDTLS + // Nothing to do #else openssl::init(); #endif diff --git a/src/impl/sha.cpp b/src/impl/sha.cpp index cb3f0424d..c93626e1d 100644 --- a/src/impl/sha.cpp +++ b/src/impl/sha.cpp @@ -14,7 +14,11 @@ #include -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + +#include + +#else #ifndef OPENSSL_API_COMPAT #define OPENSSL_API_COMPAT 0x10100000L @@ -38,7 +42,14 @@ binary Sha1(const byte *data, size_t size) { sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast(output.data())); return output; -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + + binary output(20); + mbedtls_sha1(reinterpret_cast(data), size, + reinterpret_cast(output.data())); + return output; + +#else binary output(SHA_DIGEST_LENGTH); SHA_CTX ctx; diff --git a/src/impl/tls.cpp b/src/impl/tls.cpp index 832c5c9a2..c18494ead 100644 --- a/src/impl/tls.cpp +++ b/src/impl/tls.cpp @@ -70,7 +70,81 @@ gnutls_datum_t make_datum(char *data, size_t size) { } // namespace rtc::gnutls -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + +#include + +namespace { + +// Safe gmtime +int my_gmtime(const time_t *t, struct tm *buf) { +#ifdef _WIN32 + return ::gmtime_s(buf, t) == 0 ? 0 : -1; +#else // POSIX + return ::gmtime_r(t, buf) != NULL ? 0 : -1; +#endif +} + +// Format time_t as UTC +size_t my_strftme(char *buf, size_t size, const char *format, const time_t *t) { + struct tm g; + if (my_gmtime(t, &g) != 0) + return 0; + + return ::strftime(buf, size, format, &g); +} + +} // namespace + +namespace rtc::mbedtls { + +void check(int ret, const string &message) { + if (ret < 0) { + const size_t bufferSize = 1024; + char buffer[bufferSize]; + mbedtls_strerror(ret, reinterpret_cast(buffer), bufferSize); + PLOG_ERROR << message << ": " << buffer; + throw std::runtime_error(message + ": " + std::string(buffer)); + } +} + +string format_time(const std::chrono::system_clock::time_point &tp) { + time_t t = std::chrono::system_clock::to_time_t(tp); + const size_t bufferSize = 256; + char buffer[bufferSize]; + if (my_strftme(buffer, bufferSize, "%Y%m%d%H%M%S", &t) == 0) + throw std::runtime_error("Time conversion failed"); + + return string(buffer); +}; + +std::shared_ptr new_pk_context() { + return std::shared_ptr{[]() { + auto p = new mbedtls_pk_context; + mbedtls_pk_init(p); + return p; + }(), + [](mbedtls_pk_context *p) { + mbedtls_pk_free(p); + delete p; + }}; +} + +std::shared_ptr new_x509_crt() { + return std::shared_ptr{[]() { + auto p = new mbedtls_x509_crt; + mbedtls_x509_crt_init(p); + return p; + }(), + [](mbedtls_x509_crt *crt) { + mbedtls_x509_crt_free(crt); + delete crt; + }}; +} + +} // namespace rtc::mbedtls + +#else // OPENSSL namespace rtc::openssl { diff --git a/src/impl/tls.hpp b/src/impl/tls.hpp index 242c3ab68..cf2e967ec 100644 --- a/src/impl/tls.hpp +++ b/src/impl/tls.hpp @@ -11,6 +11,8 @@ #include "common.hpp" +#include + #if USE_GNUTLS #include @@ -36,7 +38,30 @@ gnutls_datum_t make_datum(char *data, size_t size); } // namespace rtc::gnutls -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + +#include "mbedtls/ctr_drbg.h" +#include "mbedtls/ecdsa.h" +#include "mbedtls/entropy.h" +#include "mbedtls/error.h" +#include "mbedtls/pk.h" +#include "mbedtls/rsa.h" +#include "mbedtls/sha256.h" +#include "mbedtls/ssl.h" +#include "mbedtls/x509_crt.h" + +namespace rtc::mbedtls { + +void check(int ret, const string &message = "MbedTLS error"); + +string format_time(const std::chrono::system_clock::time_point &tp); + +std::shared_ptr new_pk_context(); +std::shared_ptr new_x509_crt(); + +} // namespace rtc::mbedtls + +#else // OPENSSL #ifdef _WIN32 // Include winsock2.h header first since OpenSSL may include winsock.h diff --git a/src/impl/tlstransport.cpp b/src/impl/tlstransport.cpp index b410453a7..c82836cca 100644 --- a/src/impl/tlstransport.cpp +++ b/src/impl/tlstransport.cpp @@ -296,7 +296,218 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms } } -#else // USE_GNUTLS==0 +#elif USE_MBEDTLS + +void TlsTransport::Init() { + // Nothing to do +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(variant, shared_ptr> lower, + optional host, certificate_ptr certificate, + state_callback callback) + : Transport(std::visit([](auto l) { return std::static_pointer_cast(l); }, lower), + std::move(callback)), + mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)), + mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) { + + PLOG_DEBUG << "Initializing TLS transport (MbedTLS)"; + + mbedtls_entropy_init(&mEntropy); + mbedtls_ctr_drbg_init(&mDrbg); + mbedtls_ssl_init(&mSsl); + mbedtls_ssl_config_init(&mConf); + mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0)); + + mbedtls::check(mbedtls_ssl_config_defaults( + &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)); + + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg); + + if (certificate) { + auto [crt, pk] = certificate->credentials(); + mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get())); + } + + mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf)); + mbedtls_ssl_set_bio(&mSsl, static_cast(this), WriteCallback, ReadCallback, NULL); + } catch (...) { + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); + throw; + } +} + +TlsTransport::~TlsTransport() {} + +void TlsTransport::start() { + PLOG_DEBUG << "Starting TLS transport"; + registerIncoming(); + changeState(State::Connecting); + enqueueRecv(); // to initiate the handshake +} + +void TlsTransport::stop() { + PLOG_DEBUG << "Stopping TLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool TlsTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("TLS is not open"); + + if (!message || message->size() == 0) + return outgoing(message); // pass through + + PLOG_VERBOSE << "Send size=" << message->size(); + + mbedtls::check(mbedtls_ssl_write( + &mSsl, reinterpret_cast(message->data()), int(message->size()))); + + return mOutgoingResult; +} + +void TlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + enqueueRecv(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool TlsTransport::outgoing(message_ptr message) { + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +void TlsTransport::postHandshake() { + // Dummy +} + +void TlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handle handshake if connecting + if (state() == State::Connecting) { + while (true) { + auto ret = mbedtls_ssl_handshake(&mSsl); + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || + ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) { + continue; + } + + mbedtls::check(ret); + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + postHandshake(); + break; + } + } + + if (state() == State::Connected) { + while (true) { + auto ret = + mbedtls_ssl_read(&mSsl, reinterpret_cast(buffer), bufferSize); + + if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + // Closed + PLOG_DEBUG << "TLS connection cleanly closed"; + break; + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || + ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) { + return; + } + mbedtls::check(ret); + + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + PLOG_INFO << "TLS closed"; + changeState(State::Disconnected); + recv(nullptr); +} + +int TlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + auto *b = reinterpret_cast(buf); + t->outgoing(make_message(b, b + len)); + + return int(len); +} + +int TlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) { + TlsTransport *t = static_cast(ctx); + try { + message_ptr &message = t->mIncomingMessage; + size_t &position = t->mIncomingMessagePosition; + + if (message && position >= message->size()) + message.reset(); + + if (!message) { + position = 0; + while (auto next = t->mIncomingQueue.pop()) { + message = *next; + if (message->size() > 0) + break; + else + t->recv(message); // Pass zero-sized messages through + } + } + + if (message) { + size_t available = message->size() - position; + size_t writeLen = std::min(len, available); + std::memcpy(buf, message->data() + position, writeLen); + position += writeLen; + return int(writeLen); + } else if (t->mIncomingQueue.running()) { + return MBEDTLS_ERR_SSL_WANT_READ; + } else { + return MBEDTLS_ERR_SSL_CONN_EOF; + } + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } +} + +#else int TlsTransport::TransportExIndex = -1; diff --git a/src/impl/tlstransport.hpp b/src/impl/tlstransport.hpp index 7672b7854..145f8edf1 100644 --- a/src/impl/tlstransport.hpp +++ b/src/impl/tlstransport.hpp @@ -65,6 +65,21 @@ class TlsTransport : public Transport, public std::enable_shared_from_this mOutgoingResult = true; + + mbedtls_entropy_context mEntropy; + mbedtls_ctr_drbg_context mDrbg; + mbedtls_ssl_config mConf; + mbedtls_ssl_context mSsl; + + message_ptr mIncomingMessage; + size_t mIncomingMessagePosition = 0; + + static int WriteCallback(void *ctx, const unsigned char *buf, size_t len); + static int ReadCallback(void *ctx, unsigned char *buf, size_t len); + #else SSL_CTX *mCtx; SSL *mSsl; diff --git a/src/impl/verifiedtlstransport.cpp b/src/impl/verifiedtlstransport.cpp index 5e42d5d27..14973229d 100644 --- a/src/impl/verifiedtlstransport.cpp +++ b/src/impl/verifiedtlstransport.cpp @@ -18,11 +18,13 @@ VerifiedTlsTransport::VerifiedTlsTransport( certificate_ptr certificate, state_callback callback) : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) { -#if USE_GNUTLS PLOG_DEBUG << "Setting up TLS certificate verification"; + +#if USE_GNUTLS gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0); +#elif USE_MBEDTLS + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED); #else - PLOG_DEBUG << "Setting up TLS certificate verification"; SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL); SSL_set_verify_depth(mSsl, 4); #endif