diff --git a/worker/include/DepUsrSCTP.hpp b/worker/include/DepUsrSCTP.hpp index 122cb63c47..450ddeefe6 100644 --- a/worker/include/DepUsrSCTP.hpp +++ b/worker/include/DepUsrSCTP.hpp @@ -8,6 +8,14 @@ class DepUsrSCTP { +public: + struct SendSctpDataStore + { + RTC::SctpAssociation* sctpAssociation; + uint8_t* data; + size_t len; + }; + private: class Checker : public TimerHandle::Listener { @@ -37,12 +45,15 @@ class DepUsrSCTP static void RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation); static void DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation); static RTC::SctpAssociation* RetrieveSctpAssociation(uintptr_t id); + static void SendSctpData(RTC::SctpAssociation* sctpAssociation, uint8_t* data, size_t len); + static SendSctpDataStore* GetSendSctpDataStore(uv_async_t* handle); private: thread_local static Checker* checker; static uint64_t numSctpAssociations; static uintptr_t nextSctpAssociationId; static absl::flat_hash_map mapIdSctpAssociation; + static absl::flat_hash_map mapAsyncHandlerSendSctpData; }; #endif diff --git a/worker/include/RTC/SctpAssociation.hpp b/worker/include/RTC/SctpAssociation.hpp index e2f3fe7922..d78d6a12af 100644 --- a/worker/include/RTC/SctpAssociation.hpp +++ b/worker/include/RTC/SctpAssociation.hpp @@ -6,6 +6,7 @@ #include "RTC/DataConsumer.hpp" #include "RTC/DataProducer.hpp" #include +#include namespace RTC { @@ -80,7 +81,11 @@ namespace RTC public: flatbuffers::Offset FillBuffer( flatbuffers::FlatBufferBuilder& builder) const; - void TransportConnected(); + uv_async_t* GetAsyncHandle() const + { + return this->uvAsyncHandle; + } + void InitializeSyncHandle(uv_async_cb callback); SctpState GetState() const { return this->state; @@ -89,6 +94,7 @@ namespace RTC { return this->sctpBufferedAmount; } + void TransportConnected(); void ProcessSctpData(const uint8_t* data, size_t len) const; void SendSctpMessage( RTC::DataConsumer* dataConsumer, @@ -106,7 +112,7 @@ namespace RTC /* Callbacks fired by usrsctp events. */ public: - void OnUsrSctpSendSctpData(void* buffer, size_t len); + void OnUsrSctpSendSctpData(uint8_t* data, size_t len); void OnUsrSctpReceiveSctpData( uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len); void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len); @@ -125,6 +131,7 @@ namespace RTC size_t sctpBufferedAmount{ 0u }; bool isDataChannel{ false }; // Allocated by this. + uv_async_t* uvAsyncHandle{ nullptr }; uint8_t* messageBuffer{ nullptr }; // Others. SctpState state{ SctpState::NEW }; diff --git a/worker/src/DepUsrSCTP.cpp b/worker/src/DepUsrSCTP.cpp index 700bac3348..8b6f78d611 100644 --- a/worker/src/DepUsrSCTP.cpp +++ b/worker/src/DepUsrSCTP.cpp @@ -1,5 +1,5 @@ #define MS_CLASS "DepUsrSCTP" -// #define MS_LOG_DEV_LEVEL 3 +#define MS_LOG_DEV_LEVEL 3 #include "DepUsrSCTP.hpp" #ifdef MS_LIBURING_SUPPORTED @@ -8,7 +8,8 @@ #include "DepLibUV.hpp" #include "Logger.hpp" #include -#include // std::vsnprintf() +#include // std::vsnprintf() +#include // std::memcpy() #include /* Static. */ @@ -17,10 +18,40 @@ static constexpr size_t CheckerInterval{ 10u }; // In ms. static std::mutex GlobalSyncMutex; static size_t GlobalInstances{ 0u }; +/* Static methods for UV callbacks. */ + +inline static void onAsync(uv_async_t* handle) +{ + MS_TRACE(); + MS_DUMP("---------- onAsync!!"); + + const std::lock_guard lock(GlobalSyncMutex); + + // Get the sending data from the map. + auto* store = DepUsrSCTP::GetSendSctpDataStore(handle); + + if (!store) + { + MS_WARN_DEV("store not found"); + + return; + } + + auto* sctpAssociation = store->sctpAssociation; + auto* data = store->data; + auto len = store->len; + + MS_DUMP("---------- onAsync, sending SCTP data!!"); + + sctpAssociation->OnUsrSctpSendSctpData(data, len); +} + /* Static methods for usrsctp global callbacks. */ inline static int onSendSctpData(void* addr, void* data, size_t len, uint8_t /*tos*/, uint8_t /*setDf*/) { + MS_TRACE(); + auto* sctpAssociation = DepUsrSCTP::RetrieveSctpAssociation(reinterpret_cast(addr)); if (!sctpAssociation) @@ -30,7 +61,7 @@ inline static int onSendSctpData(void* addr, void* data, size_t len, uint8_t /*t return -1; } - sctpAssociation->OnUsrSctpSendSctpData(data, len); + DepUsrSCTP::SendSctpData(sctpAssociation, static_cast(data), len); // NOTE: Must not free data, usrsctp lib does it. @@ -60,6 +91,7 @@ thread_local DepUsrSCTP::Checker* DepUsrSCTP::checker{ nullptr }; uint64_t DepUsrSCTP::numSctpAssociations{ 0u }; uintptr_t DepUsrSCTP::nextSctpAssociationId{ 0u }; absl::flat_hash_map DepUsrSCTP::mapIdSctpAssociation; +absl::flat_hash_map DepUsrSCTP::mapAsyncHandlerSendSctpData; /* Static methods. */ @@ -91,6 +123,7 @@ void DepUsrSCTP::ClassDestroy() MS_TRACE(); const std::lock_guard lock(GlobalSyncMutex); + --GlobalInstances; if (GlobalInstances == 0) @@ -101,6 +134,7 @@ void DepUsrSCTP::ClassDestroy() nextSctpAssociationId = 0u; DepUsrSCTP::mapIdSctpAssociation.clear(); + DepUsrSCTP::mapAsyncHandlerSendSctpData.clear(); } } @@ -158,13 +192,20 @@ void DepUsrSCTP::RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation) MS_ASSERT(DepUsrSCTP::checker != nullptr, "Checker not created"); - auto it = DepUsrSCTP::mapIdSctpAssociation.find(sctpAssociation->id); + auto it = DepUsrSCTP::mapIdSctpAssociation.find(sctpAssociation->id); + auto it2 = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(sctpAssociation->GetAsyncHandle()); MS_ASSERT( it == DepUsrSCTP::mapIdSctpAssociation.end(), - "the id of the SctpAssociation is already in the map"); + "the id of the SctpAssociation is already in the mapIdSctpAssociation map"); + MS_ASSERT( + it2 == DepUsrSCTP::mapAsyncHandlerSendSctpData.end(), + "the id of the SctpAssociation is already in the mapAsyncHandlerSendSctpData map"); DepUsrSCTP::mapIdSctpAssociation[sctpAssociation->id] = sctpAssociation; + DepUsrSCTP::mapAsyncHandlerSendSctpData[sctpAssociation->GetAsyncHandle()]; + + sctpAssociation->InitializeSyncHandle(onAsync); if (++DepUsrSCTP::numSctpAssociations == 1u) { @@ -180,9 +221,11 @@ void DepUsrSCTP::DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation MS_ASSERT(DepUsrSCTP::checker != nullptr, "Checker not created"); - auto found = DepUsrSCTP::mapIdSctpAssociation.erase(sctpAssociation->id); + auto found1 = DepUsrSCTP::mapIdSctpAssociation.erase(sctpAssociation->id); + auto found2 = DepUsrSCTP::mapAsyncHandlerSendSctpData.erase(sctpAssociation->GetAsyncHandle()); - MS_ASSERT(found > 0, "SctpAssociation not found"); + MS_ASSERT(found1 > 0, "SctpAssociation not found in mapIdSctpAssociation map"); + MS_ASSERT(found2 > 0, "SctpAssociation not found in mapAsyncHandlerSendSctpData map"); MS_ASSERT(DepUsrSCTP::numSctpAssociations > 0u, "numSctpAssociations was not higher than 0"); if (--DepUsrSCTP::numSctpAssociations == 0u) @@ -207,6 +250,56 @@ RTC::SctpAssociation* DepUsrSCTP::RetrieveSctpAssociation(uintptr_t id) return it->second; } +void DepUsrSCTP::SendSctpData(RTC::SctpAssociation* sctpAssociation, uint8_t* data, size_t len) +{ + MS_TRACE(); + + const std::lock_guard lock(GlobalSyncMutex); + + // Store the sending data into the map. + + auto it = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(sctpAssociation->GetAsyncHandle()); + + MS_ASSERT( + it != DepUsrSCTP::mapAsyncHandlerSendSctpData.end(), + "SctpAssociation not found in mapAsyncHandlerSendSctpData map"); + + SendSctpDataStore& store = it->second; + + // NOTE: In Rust, DepUsrSCTP::SendSctpData() is called from onSendSctpData() + // callback from a different thread and usrsctp immediately frees |data| when + // the callback execution finishes. So we have to mem copy it. + store.sctpAssociation = sctpAssociation; + store.data = new uint8_t[len]; + store.len = len; + + std::memcpy(store.data, data, len); + + // Invoke UV async send. + int err = uv_async_send(sctpAssociation->GetAsyncHandle()); + + if (err != 0) + { + MS_WARN_TAG(sctp, "uv_async_send() failed: %s", uv_strerror(err)); + } +} + +DepUsrSCTP::SendSctpDataStore* DepUsrSCTP::GetSendSctpDataStore(uv_async_t* handle) +{ + MS_TRACE(); + + auto it = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(handle); + + if (it == DepUsrSCTP::mapAsyncHandlerSendSctpData.end()) + { + return nullptr; + } + + SendSctpDataStore& store = it->second; + + return std::addressof(store); +} + /* DepUsrSCTP::Checker instance methods. */ DepUsrSCTP::Checker::Checker() : timer(new TimerHandle(this)) diff --git a/worker/src/RTC/SctpAssociation.cpp b/worker/src/RTC/SctpAssociation.cpp index 499ec2a053..f64edfcf05 100644 --- a/worker/src/RTC/SctpAssociation.cpp +++ b/worker/src/RTC/SctpAssociation.cpp @@ -2,6 +2,7 @@ // #define MS_LOG_DEV_LEVEL 3 #include "RTC/SctpAssociation.hpp" +#include "DepLibUV.hpp" #include "DepUsrSCTP.hpp" #include "Logger.hpp" #include "MediaSoupErrors.hpp" @@ -121,6 +122,9 @@ namespace RTC { MS_TRACE(); + // Create a uv_async_t handle. + this->uvAsyncHandle = new uv_async_t; + // Register ourselves in usrsctp. // NOTE: This must be done before calling usrsctp_bind(). usrsctp_register_address(reinterpret_cast(this->id)); @@ -293,6 +297,7 @@ namespace RTC // Register the SctpAssociation from the global map. DepUsrSCTP::DeregisterSctpAssociation(this); + delete this->uvAsyncHandle; delete[] this->messageBuffer; } @@ -381,6 +386,18 @@ namespace RTC this->isDataChannel); } + void SctpAssociation::InitializeSyncHandle(uv_async_cb callback) + { + MS_TRACE(); + + int err = uv_async_init(DepLibUV::GetLoop(), this->uvAsyncHandle, callback); + + if (err != 0) + { + MS_ABORT("uv_async_init() failed: %s", uv_strerror(err)); + } + } + void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) const { MS_TRACE(); @@ -667,12 +684,10 @@ namespace RTC } } - void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len) + void SctpAssociation::OnUsrSctpSendSctpData(uint8_t* data, size_t len) { MS_TRACE(); - const uint8_t* data = static_cast(buffer); - #if MS_LOG_DEV_LEVEL == 3 MS_DUMP_DATA(data, len); #endif