Skip to content

Commit

Permalink
Changed stream list to arrays and optimized stream lookup with SSE2.
Browse files Browse the repository at this point in the history
The stream list in the SRTP context is now implemented with two arrays:
an array of SSRCs and an array of pointers to the streams corresponding
to the SSRCs. The streams no longer form a linked list.

Stream lookup by SSRC is now performed over the array of SSRCs, which
is considerably faster because it is more cache-friendly. Additionally,
the lookup is optimized for SSE2, which provides an additional massive
speedup with many streams in the list. Although the lookup still has
linear complexity, its absolute times are reduced and with tens to
hundreds elements are lower or comparable with a typical rb-tree
equivalent.

Expected speedup of SSE2 version over the previous implementation:

SSRCs    speedup (scalar)   speedup (SSE2)

1        0.39x              0.22x
3        0.57x              0.23x
5        0.69x              0.62x
10       0.77x              1.43x
20       0.86x              2.38x
30       0.87x              3.44x
50       1.13x              6.21x
100      1.25x              8.51x
200      1.30x              9.83x

These numbers were obtained on a Core i7 2600K.

At small numbers of SSRCs the new algorithm is somewhat slower, but
given that the absolute and relative times of the lookup are very small,
that slowdown is not very significant.
  • Loading branch information
Lastique committed Feb 28, 2023
1 parent 0d8b2d7 commit 1235205
Showing 1 changed file with 198 additions and 34 deletions.
232 changes: 198 additions & 34 deletions srtp/srtp.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,22 @@
#include "aes_icm_ext.h"
#endif

#include <stddef.h>
#include <string.h>
#include <limits.h>
#ifdef HAVE_NETINET_IN_H
#include <netinet/in.h>
#elif defined(HAVE_WINSOCK2_H)
#include <winsock2.h>
#endif

#if defined(__SSE2__)
#include <emmintrin.h>
#if defined(_MSC_VER)
#include <intrin.h>
#endif
#endif

/* the debug module for srtp */
srtp_debug_module_t mod_srtp = {
0, /* debugging is off by default */
Expand All @@ -79,6 +88,17 @@ srtp_debug_module_t mod_srtp = {
#define uint32s_in_rtcp_header 2
#define octets_in_rtp_extn_hdr 4

#ifndef SRTP_NO_STREAM_LIST
static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list);
static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list,
uint32_t new_capacity);
static uint32_t srtp_stream_list_find(srtp_stream_list_t list,
uint32_t ssrc);
static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list,
uint32_t pos);
static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos);
#endif // SRTP_NO_STREAM_LIST

static srtp_err_status_t srtp_validate_rtp_header(void *rtp_hdr,
int *pkt_octet_len)
{
Expand Down Expand Up @@ -3030,18 +3050,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
{
srtp_stream_ctx_t *stream;
srtp_err_status_t status;
#if !defined(SRTP_NO_STREAM_LIST)
uint32_t pos;
#endif

/* sanity check arguments */
if (session == NULL)
if (session == NULL) {
return srtp_err_status_bad_param;
}

/* find and remove stream from the list */
#if !defined(SRTP_NO_STREAM_LIST)
pos = srtp_stream_list_find(session->stream_list, ssrc);
if (pos >= srtp_stream_list_size(session->stream_list))
return srtp_err_status_no_ctx;

stream = srtp_stream_list_get_at(session->stream_list, pos);
srtp_stream_list_remove_at(session->stream_list, pos);
#else
stream = srtp_stream_list_get(session->stream_list, ssrc);
if (stream == NULL) {
return srtp_err_status_no_ctx;
}

srtp_stream_list_remove(session->stream_list, stream);
#endif

/* deallocate the stream */
status = srtp_stream_dealloc(stream, session->stream_template);
Expand Down Expand Up @@ -4840,11 +4873,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,

#ifndef SRTP_NO_STREAM_LIST

/* in the default implementation, we have an intrusive doubly-linked list */
typedef struct srtp_stream_list_ctx_t_ {
/* a stub stream that just holds pointers to the beginning and end of the
* list */
srtp_stream_ctx_t data;
uint32_t *ssrcs;
srtp_stream_ctx_t **streams;
uint32_t size;
uint32_t capacity;
} srtp_stream_list_ctx_t_;

srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
Expand All @@ -4855,73 +4888,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
return srtp_err_status_alloc_fail;
}

list->data.next = NULL;
list->data.prev = NULL;

*list_ptr = list;
return srtp_err_status_ok;
}

srtp_err_status_t srtp_stream_list_dealloc(srtp_stream_list_t list)
{
/* list must be empty */
if (list->data.next) {
if (list->size != 0u) {
return srtp_err_status_fail;
}
srtp_crypto_free(list->streams);
srtp_crypto_free(list->ssrcs);
srtp_crypto_free(list);
return srtp_err_status_ok;
}

static inline uint32_t srtp_stream_list_size(srtp_stream_list_t list)
{
return list->size;
}

static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t list,
uint32_t new_capacity)
{
if (new_capacity > list->capacity) {
uint32_t *ssrcs;
srtp_stream_ctx_t **stream_ptrs;

if (new_capacity > (UINT32_MAX - 15u))
return srtp_err_status_alloc_fail;

new_capacity = (new_capacity + 15u) & ~((uint32_t)15u);

ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity *
sizeof(uint32_t));
if (!ssrcs)
return srtp_err_status_alloc_fail;
stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc(
(size_t)new_capacity * sizeof(srtp_stream_ctx_t *));
if (!stream_ptrs) {
srtp_crypto_free(ssrcs);
return srtp_err_status_alloc_fail;
}

if (list->size > 0u) {
memcpy(ssrcs, list->ssrcs, (size_t)list->size * sizeof(uint32_t));
memcpy(stream_ptrs, list->streams,
(size_t)list->size * sizeof(srtp_stream_ctx_t *));
}

srtp_crypto_free(list->ssrcs);
srtp_crypto_free(list->streams);
list->streams = stream_ptrs;
list->ssrcs = ssrcs;

list->capacity = new_capacity;
}

return srtp_err_status_ok;
}

srtp_err_status_t srtp_stream_list_insert(srtp_stream_list_t list,
srtp_stream_t stream)
{
/* insert at the head of the list */
stream->next = list->data.next;
if (stream->next != NULL) {
stream->next->prev = stream;
}
list->data.next = stream;
stream->prev = &(list->data);
uint32_t pos;
srtp_err_status_t status = srtp_stream_list_reserve(list, list->size + 1u);
if (status)
return status;
pos = list->size++;
list->ssrcs[pos] = stream->ssrc;
list->streams[pos] = stream;

return srtp_err_status_ok;
}

srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc)
static uint32_t srtp_stream_list_find(srtp_stream_list_t list, uint32_t ssrc)
{
/* walk down list until ssrc is found */
srtp_stream_t stream = list->data.next;
while (stream != NULL) {
if (stream->ssrc == ssrc) {
return stream;
#if defined(__SSE2__)
const uint32_t *const ssrcs = list->ssrcs;
const __m128i mm_ssrc = _mm_set1_epi32(ssrc);
uint32_t pos = 0u, n = (list->size + 7u) & ~(uint32_t)(7u);
for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) {
__m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos));
__m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u));
__m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u));
__m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u));
mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc);
mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc);
mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc);
mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc);
mm1 = _mm_packs_epi32(mm1, mm2);
mm3 = _mm_packs_epi32(mm3, mm4);
mm1 = _mm_packs_epi16(mm1, mm3);
uint32_t mask = _mm_movemask_epi8(mm1);
if (mask) {
#if defined(_MSC_VER)
unsigned long bit_pos;
_BitScanForward(&bit_pos, mask);
pos += bit_pos;
#else
pos += __builtin_ctz(mask);
#endif

goto done;
}
}

if (pos < n) {
__m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos));
__m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u));
mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc);
mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc);
mm1 = _mm_packs_epi32(mm1, mm2);

uint32_t mask = _mm_movemask_epi8(mm1);
if (mask) {
#if defined(_MSC_VER)
unsigned long bit_pos;
_BitScanForward(&bit_pos, mask);
pos += bit_pos / 2u;
#else
pos += __builtin_ctz(mask) / 2u;
#endif
goto done;
}
stream = stream->next;

pos += 8u;
}

done:
return pos;
#else
/* walk down list until ssrc is found */
uint32_t pos = 0u, n = list->size;
for (; pos < n; ++pos) {
if (list->ssrcs[pos] == ssrc)
break;
}

return pos;
#endif
}

static inline srtp_stream_t srtp_stream_list_get_at(srtp_stream_list_t list,
uint32_t pos)
{
return list->streams[pos];
}

srtp_stream_t srtp_stream_list_get(srtp_stream_list_t list, uint32_t ssrc)
{
uint32_t pos = srtp_stream_list_find(list, ssrc);
if (pos < list->size)
return list->streams[pos];

/* we haven't found our ssrc, so return a null */
return NULL;
}

void srtp_stream_list_remove(srtp_stream_list_t list,
srtp_stream_t stream_to_remove)
static void srtp_stream_list_remove_at(srtp_stream_list_t list, uint32_t pos)
{
(void)list;
uint32_t tail_size, last_pos;

stream_to_remove->prev->next = stream_to_remove->next;
if (stream_to_remove->next != NULL) {
stream_to_remove->next->prev = stream_to_remove->prev;
last_pos = --list->size;
tail_size = last_pos - pos;
if (tail_size > 0u) {
memmove(list->streams + pos, list->streams + pos + 1,
(size_t)tail_size * sizeof(*list->streams));
memmove(list->ssrcs + pos, list->ssrcs + pos + 1,
(size_t)tail_size * sizeof(*list->ssrcs));
}

list->streams[last_pos] = NULL;
list->ssrcs[last_pos] = 0u;
}

void srtp_stream_list_remove(srtp_stream_list_t list,
srtp_stream_t stream_to_remove)
{
uint32_t pos = srtp_stream_list_find(list, stream_to_remove->ssrc);
if (pos < list->size)
srtp_stream_list_remove_at(list, pos);
}

void srtp_stream_list_for_each(srtp_stream_list_t list,
int (*callback)(srtp_stream_t, void *),
void *data)
{
srtp_stream_t stream = list->data.next;
while (stream != NULL) {
srtp_stream_t tmp = stream;
stream = stream->next;
if (callback(tmp, data))
uint32_t size = list->size;
for (uint32_t i = 0u; i < size;) {
if (callback(list->streams[i], data))
break;

/* check if the callback removed the current element */
if (size == list->size)
++i;
else
size = list->size;
}
}

Expand Down

0 comments on commit 1235205

Please sign in to comment.