Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the usrsctp vulnerability by using a global map for SctpAssociations #439

Merged
merged 2 commits into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions worker/include/DepUsrSCTP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#define MS_DEP_USRSCTP_HPP

#include "common.hpp"
#include "RTC/SctpAssociation.hpp"
#include "handles/Timer.hpp"
#include <unordered_map>

class DepUsrSCTP
{
Expand All @@ -29,12 +31,16 @@ class DepUsrSCTP
public:
static void ClassInit();
static void ClassDestroy();
static void IncreaseSctpAssociations();
static void DecreaseSctpAssociations();
static uintptr_t GetNextSctpAssociationId();
static void RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation);
static void DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation);
static RTC::SctpAssociation* RetrieveSctpAssociation(uintptr_t id);

private:
static Checker* checker;
static uint64_t numSctpAssociations;
static uintptr_t nextSctpAssociationId;
static std::unordered_map<uintptr_t, RTC::SctpAssociation*> mapIdSctpAssociation;
};

#endif
3 changes: 3 additions & 0 deletions worker/include/RTC/SctpAssociation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ namespace RTC
uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len);
void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len);

public:
uintptr_t id{ 0u };

private:
// Passed by argument.
Listener* listener{ nullptr };
Expand Down
58 changes: 54 additions & 4 deletions worker/src/DepUsrSCTP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "DepUsrSCTP.hpp"
#include "DepLibUV.hpp"
#include "Logger.hpp"
#include "RTC/SctpAssociation.hpp"
#include <usrsctp.h>

/* Static. */
Expand All @@ -15,10 +14,14 @@ static constexpr size_t CheckerInterval{ 10u }; // In ms.

inline static int onSendSctpData(void* addr, void* data, size_t len, uint8_t /*tos*/, uint8_t /*setDf*/)
{
auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(addr);
auto* sctpAssociation = DepUsrSCTP::RetrieveSctpAssociation(reinterpret_cast<uintptr_t>(addr));

if (!sctpAssociation)
{
MS_WARN_TAG(sctp, "no SctpAssociation found");

return -1;
}

sctpAssociation->OnUsrSctpSendSctpData(data, len);

Expand Down Expand Up @@ -48,6 +51,8 @@ inline static void sctpDebug(const char* format, ...)

DepUsrSCTP::Checker* DepUsrSCTP::checker{ nullptr };
uint64_t DepUsrSCTP::numSctpAssociations{ 0u };
uintptr_t DepUsrSCTP::nextSctpAssociationId{ 0u };
std::unordered_map<uintptr_t, RTC::SctpAssociation*> DepUsrSCTP::mapIdSctpAssociation;

/* Static methods. */

Expand Down Expand Up @@ -78,24 +83,69 @@ void DepUsrSCTP::ClassDestroy()
delete DepUsrSCTP::checker;
}

void DepUsrSCTP::IncreaseSctpAssociations()
uintptr_t DepUsrSCTP::GetNextSctpAssociationId()
{
MS_TRACE();

// NOTE: usrsctp_connect() fails with a value of 0.
if (DepUsrSCTP::nextSctpAssociationId == 0u)
++DepUsrSCTP::nextSctpAssociationId;

// In case we've wrapped around and need to find an empty spot from a removed
// SctpAssociation. Assumes we'll never be full.
while (DepUsrSCTP::mapIdSctpAssociation.find(DepUsrSCTP::nextSctpAssociationId) !=
DepUsrSCTP::mapIdSctpAssociation.end())
{
++DepUsrSCTP::nextSctpAssociationId;

if (DepUsrSCTP::nextSctpAssociationId == 0u)
++DepUsrSCTP::nextSctpAssociationId;
}

return DepUsrSCTP::nextSctpAssociationId++;
}

void DepUsrSCTP::RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation)
{
MS_TRACE();

auto it = DepUsrSCTP::mapIdSctpAssociation.find(sctpAssociation->id);

MS_ASSERT(
it == DepUsrSCTP::mapIdSctpAssociation.end(),
"the id of the SctpAssociation is already in the map");

DepUsrSCTP::mapIdSctpAssociation[sctpAssociation->id] = sctpAssociation;

if (++DepUsrSCTP::numSctpAssociations == 1u)
DepUsrSCTP::checker->Start();
}

void DepUsrSCTP::DecreaseSctpAssociations()
void DepUsrSCTP::DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation)
{
MS_TRACE();

auto found = DepUsrSCTP::mapIdSctpAssociation.erase(sctpAssociation->id);

MS_ASSERT(found > 0, "SctpAssociation not found");
MS_ASSERT(DepUsrSCTP::numSctpAssociations > 0u, "numSctpAssociations was not higher than 0");

if (--DepUsrSCTP::numSctpAssociations == 0u)
DepUsrSCTP::checker->Stop();
}

RTC::SctpAssociation* DepUsrSCTP::RetrieveSctpAssociation(uintptr_t id)
{
MS_TRACE();

auto it = DepUsrSCTP::mapIdSctpAssociation.find(id);

if (it == DepUsrSCTP::mapIdSctpAssociation.end())
return nullptr;

return it->second;
}

/* DepUsrSCTP::Checker instance methods. */

DepUsrSCTP::Checker::Checker()
Expand Down
52 changes: 40 additions & 12 deletions worker/src/RTC/SctpAssociation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ inline static int onRecvSctpData(
int flags,
void* ulpInfo)
{
auto* sctpAssociation = static_cast<RTC::SctpAssociation*>(ulpInfo);
auto* sctpAssociation = DepUsrSCTP::RetrieveSctpAssociation(reinterpret_cast<uintptr_t>(ulpInfo));

if (!sctpAssociation)
{
MS_WARN_TAG(sctp, "no SctpAssociation found");

std::free(data);

return 0;
Expand Down Expand Up @@ -99,24 +101,34 @@ namespace RTC
{
MS_TRACE();

// Get a id for this SctpAssociation.
this->id = DepUsrSCTP::GetNextSctpAssociationId();

// Register ourselves in usrsctp.
usrsctp_register_address(static_cast<void*>(this));
// NOTE: This must be done before calling usrsctp_bind().
usrsctp_register_address(reinterpret_cast<void*>(this->id));

int ret;

this->socket = usrsctp_socket(
AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, static_cast<void*>(this));
AF_CONN, SOCK_STREAM, IPPROTO_SCTP, onRecvSctpData, nullptr, 0, reinterpret_cast<void*>(this->id));

if (!this->socket)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_socket() failed: %s", std::strerror(errno));
}

usrsctp_set_ulpinfo(this->socket, static_cast<void*>(this));
usrsctp_set_ulpinfo(this->socket, reinterpret_cast<void*>(this->id));

// Make the socket non-blocking.
ret = usrsctp_set_non_blocking(this->socket, 1);

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_set_non_blocking() failed: %s", std::strerror(errno));
}

Expand All @@ -133,6 +145,8 @@ namespace RTC

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SO_LINGER) failed: %s", std::strerror(errno));
}

Expand All @@ -146,6 +160,8 @@ namespace RTC

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SCTP_ENABLE_STREAM_RESET) failed: %s", std::strerror(errno));
}

Expand All @@ -156,6 +172,8 @@ namespace RTC

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SCTP_NODELAY) failed: %s", std::strerror(errno));
}

Expand All @@ -173,6 +191,8 @@ namespace RTC

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SCTP_EVENT) failed: %s", std::strerror(errno));
}
}
Expand All @@ -188,6 +208,8 @@ namespace RTC

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SCTP_INITMSG) failed: %s", std::strerror(errno));
}

Expand All @@ -197,26 +219,31 @@ namespace RTC
std::memset(&sconn, 0, sizeof(sconn));
sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(5000);
sconn.sconn_addr = static_cast<void*>(this);
sconn.sconn_addr = reinterpret_cast<void*>(this->id);
#ifdef HAVE_SCONN_LEN
rconn.sconn_len = sizeof(sconn);
sconn.sconn_len = sizeof(sconn);
#endif

ret = usrsctp_bind(this->socket, reinterpret_cast<struct sockaddr*>(&sconn), sizeof(sconn));

if (ret < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_bind() failed: %s", std::strerror(errno));
}

DepUsrSCTP::IncreaseSctpAssociations();

auto bufferSize = static_cast<int>(sctpSendBufferSize);

if (usrsctp_setsockopt(this->socket, SOL_SOCKET, SO_SNDBUF, &bufferSize, sizeof(int)) < 0)
{
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

MS_THROW_ERROR("usrsctp_setsockopt(SO_SNDBUF) failed: %s", std::strerror(errno));
}

// Register the SctpAssociation into the global map.
DepUsrSCTP::RegisterSctpAssociation(this);
}

SctpAssociation::~SctpAssociation()
Expand All @@ -227,9 +254,10 @@ namespace RTC
usrsctp_close(this->socket);

// Deregister ourselves from usrsctp.
usrsctp_deregister_address(static_cast<void*>(this));
usrsctp_deregister_address(reinterpret_cast<void*>(this->id));

DepUsrSCTP::DecreaseSctpAssociations();
// Register the SctpAssociation from the global map.
DepUsrSCTP::DeregisterSctpAssociation(this);

delete[] this->messageBuffer;
}
Expand All @@ -250,7 +278,7 @@ namespace RTC
std::memset(&rconn, 0, sizeof(rconn));
rconn.sconn_family = AF_CONN;
rconn.sconn_port = htons(5000);
rconn.sconn_addr = static_cast<void*>(this);
rconn.sconn_addr = reinterpret_cast<void*>(this->id);
#ifdef HAVE_SCONN_LEN
rconn.sconn_len = sizeof(rconn);
#endif
Expand Down Expand Up @@ -316,7 +344,7 @@ namespace RTC
MS_DUMP_DATA(data, len);
#endif

usrsctp_conninput(static_cast<void*>(this), data, len, 0);
usrsctp_conninput(reinterpret_cast<void*>(this->id), data, len, 0);
}

void SctpAssociation::SendSctpMessage(
Expand Down