Skip to content

Commit

Permalink
[core] Fix the group option that should be taken from a socket (Haivi…
Browse files Browse the repository at this point in the history
  • Loading branch information
ethouris authored Aug 14, 2024
1 parent 430a67a commit bd071e1
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 4 deletions.
72 changes: 70 additions & 2 deletions srtcore/group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,15 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen)
w_optlen = sizeof(uint32_t);
return;

case SRTO_KMSTATE:
*(uint32_t*)pw_optval = getGroupEncryptionState();
w_optlen = sizeof(uint32_t);
return;

// Write-only options for security reasons or
// options that refer to a socket state, that
// makes no sense for a group.
case SRTO_PASSPHRASE:
case SRTO_KMSTATE:
case SRTO_PBKEYLEN:
case SRTO_KMPREANNOUNCE:
case SRTO_KMREFRESHRATE:
case SRTO_BINDTODEVICE:
Expand All @@ -775,6 +778,19 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen)
default:; // pass on
}

bool is_set_on_socket = false;
{
ScopedLock cg(m_GroupLock);
gli_t gi = m_Group.begin();
if (gi != m_Group.end())
{
// Return the value from the first member socket, if any is present
// Note: Will throw exception if the request is wrong.
gi->ps->core().getOpt(optname, (pw_optval), (w_optlen));
is_set_on_socket = true;
}
}

// Check if the option is in the storage, which means that
// it was modified on the group.

Expand All @@ -783,12 +799,18 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen)

if (i == m_config.end())
{
// Already written to the target variable.
if (is_set_on_socket)
return;

// Not found, see the defaults
if (!getOptDefault(optname, (pw_optval), (w_optlen)))
throw CUDTException(MJ_NOTSUP, MN_INVAL, 0);

return;
}
// NOTE: even if is_set_on_socket, if it was also found in the group
// settings, overwrite with the value from the group.

// Found, return the value from the storage.
// Check the size first.
Expand All @@ -799,6 +821,52 @@ void CUDTGroup::getOpt(SRT_SOCKOPT optname, void* pw_optval, int& w_optlen)
memcpy((pw_optval), &i->value[0], i->value.size());
}

SRT_KM_STATE CUDTGroup::getGroupEncryptionState()
{
multiset<SRT_KM_STATE> kmstates;
{
ScopedLock lk (m_GroupLock);

// First check the container. If empty, return UNSECURED
if (m_Group.empty())
return SRT_KM_S_UNSECURED;

for (gli_t gi = m_Group.begin(); gi != m_Group.end(); ++gi)
{
CCryptoControl* cc = gi->ps->core().m_pCryptoControl.get();
if (!cc)
continue;
SRT_KM_STATE gst = cc->m_RcvKmState;
// A fix to NOSECRET is because this is the state when agent has set
// no password, but peer did, and ENFORCEDENCRYPTION=false allowed
// this connection to be established. UNSECURED can't be taken in this
// case because this would suggest that BOTH are unsecured, that is,
// we have established an unsecured connection (which ain't true).
if (gst == SRT_KM_S_UNSECURED && cc->m_SndKmState == SRT_KM_S_NOSECRET)
gst = SRT_KM_S_NOSECRET;
kmstates.insert(gst);
}
}

// Criteria are:
// 1. UNSECURED, if no member sockets, or at least one UNSECURED found.
// 2. SECURED, if at least one SECURED found (cut off the previous criteria).
// 3. BADSECRET otherwise, although return NOSECRET if no BADSECRET is found.

if (kmstates.count(SRT_KM_S_UNSECURED))
return SRT_KM_S_UNSECURED;

// Now we have UNSECURED ruled out. Remaining may be NOSECRET, BADSECRET or SECURED.
// NOTE: SECURING is an intermediate state for HSv4 and can't occur in groups.
if (kmstates.count(SRT_KM_S_SECURED))
return SRT_KM_S_SECURED;

if (kmstates.count(SRT_KM_S_BADSECRET))
return SRT_KM_S_BADSECRET;

return SRT_KM_S_NOSECRET;
}

SRT_SOCKSTATUS CUDTGroup::getStatus()
{
typedef vector<pair<SRTSOCKET, SRT_SOCKSTATUS> > states_t;
Expand Down
2 changes: 2 additions & 0 deletions srtcore/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class CUDTGroup

void send_CheckValidSockets();

SRT_KM_STATE getGroupEncryptionState();

public:
int recv(char* buf, int len, SRT_MSGCTRL& w_mc);

Expand Down
27 changes: 27 additions & 0 deletions test/test_bonding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,14 @@ TEST(Bonding, Options)
string pass = "longenoughpassword";
// passphrase should be ok.
EXPECT_NE(srt_setsockflag(grp, SRTO_PASSPHRASE, pass.c_str(), pass.size()), SRT_ERROR);

uint32_t val = 16;
EXPECT_NE(srt_setsockflag(grp, SRTO_PBKEYLEN, &val, sizeof val), SRT_ERROR);

#ifdef ENABLE_AEAD_API_PREVIEW
val = 1;
EXPECT_NE(srt_setsockflag(grp, SRTO_CRYPTOMODE, &val, sizeof val), SRT_ERROR);
#endif
#endif

int lat = 500;
Expand Down Expand Up @@ -446,6 +454,25 @@ TEST(Bonding, Options)
EXPECT_EQ(optsize, sizeof ohead);
EXPECT_EQ(ohead, 12);

#if SRT_ENABLE_ENCRYPTION

uint32_t kms = -1;

EXPECT_NE(srt_getsockflag(grp, SRTO_KMSTATE, &kms, &optsize), SRT_ERROR);
EXPECT_EQ(optsize, sizeof kms);
EXPECT_EQ(kms, int(SRT_KM_S_SECURED));

EXPECT_NE(srt_getsockflag(grp, SRTO_PBKEYLEN, &kms, &optsize), SRT_ERROR);
EXPECT_EQ(optsize, sizeof kms);
EXPECT_EQ(kms, 16);

#ifdef ENABLE_AEAD_API_PREVIEW
EXPECT_NE(srt_getsockflag(grp, SRTO_CRYPTOMODE, &kms, &optsize), SRT_ERROR);
EXPECT_EQ(optsize, sizeof kms);
EXPECT_EQ(kms, 1);
#endif
#endif

// We're done, the thread can close connection and exit
{
// Make sure that the thread reached the wait() call.
Expand Down
4 changes: 2 additions & 2 deletions test/test_crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace srt
m_crypt.setCryptoKeylen(cfg.iSndCryptoKeyLen);

cfg.iCryptoMode = CSrtConfig::CIPHER_MODE_AES_GCM;
EXPECT_EQ(m_crypt.init(HSD_INITIATOR, cfg, true), HaiCrypt_IsAESGCM_Supported() != 0);
EXPECT_TRUE(m_crypt.init(HSD_INITIATOR, cfg, true, HaiCrypt_IsAESGCM_Supported()));

const unsigned char* kmmsg = m_crypt.getKmMsg_data(0);
const size_t km_len = m_crypt.getKmMsg_size(0);
Expand All @@ -53,7 +53,7 @@ namespace srt

std::array<uint32_t, 72> km_nworder;
NtoHLA(km_nworder.data(), reinterpret_cast<const uint32_t*>(kmmsg), km_len);
m_crypt.processSrtMsg_KMREQ(km_nworder.data(), km_len, 5, kmout, kmout_len);
m_crypt.processSrtMsg_KMREQ(km_nworder.data(), km_len, 5, SrtVersion(1, 5, 3), kmout, kmout_len);
}

void TearDown() override
Expand Down

0 comments on commit bd071e1

Please sign in to comment.