diff --git a/src/lib/lwan-websocket.c b/src/lib/lwan-websocket.c index b83c6950f..603b44ebf 100644 --- a/src/lib/lwan-websocket.c +++ b/src/lib/lwan-websocket.c @@ -150,63 +150,90 @@ static size_t get_frame_length(struct lwan_request *request, uint16_t header) static void unmask(char *msg, size_t msg_len, char mask[static 4]) { - const int32_t mask32 = (int32_t)string_as_uint32(mask); - const char *msg_end = msg + msg_len; + /* TODO: handle alignment of `msg` to use (at least) NT loads + * as we're rewriting msg anyway. (NT writes aren't that + * useful as the unmasked value will be used right after.) */ #if defined(__AVX2__) - const size_t len256 = msg_len / 32; - if (len256) { - const __m256i mask256 = _mm256_setr_epi32( - mask32, mask32, mask32, mask32, mask32, mask32, mask32, mask32); - for (size_t i = 0; i < len256; i++) { - __m256i v = _mm256_loadu_si256((__m256i *)msg); + const __m256i mask256 = + _mm256_castps_si256(_mm256_broadcast_ss((const float *)mask)); + if (msg_len >= 32) { + do { + __m256i v = _mm256_lddqu_si256((const __m256i *)msg); _mm256_storeu_si256((__m256i *)msg, _mm256_xor_si256(v, mask256)); - msg += 32; - } - msg_len = (size_t)(msg_end - msg); + msg += 32; + msg_len -= 32; + } while (msg_len >= 32); } #endif #if defined(__SSE2__) - const size_t len128 = msg_len / 16; - if (len128) { - const __m128i mask128 = _mm_setr_epi32(mask32, mask32, mask32, mask32); - for (size_t i = 0; i < len128; i++) { - __m128i v = _mm_loadu_si128((__m128i *)msg); +#if defined(__AVX2__) + const __m128i mask128 = _mm256_extracti128_si256(mask256, 0); +#elif defined(__SSE3__) + const __m128i mask128 = _mm_lddqu_si128((const float *)mask); +#else + const __m128i mask128 = _mm_loadu_si128((const __m128i *)mask); +#endif + if (msg_len >= 16) { + do { +#if defined(__SSE3__) + __m128i v = _mm_lddqu_si128((const __m128i *)msg); +#else + __m128i v = _mm_loadu_si128((const __m128i *)msg); +#endif + _mm_storeu_si128((__m128i *)msg, _mm_xor_si128(v, mask128)); - msg += 16; - } - msg_len = (size_t)(msg_end - msg); + msg += 16; + msg_len -= 16; + } while (msg_len >= 16); } #endif if (sizeof(void *) == 8) { - const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32; - const size_t len64 = msg_len / 8; - for (size_t i = 0; i < len64; i++) { - uint64_t v = string_as_uint64(msg); - v ^= mask64; - msg = mempcpy(msg, &v, sizeof(v)); + if (msg_len >= 8) { +#if defined(__SSE_4_1__) + /* We're far away enough from the AVX2 path that it's + * probably better to use mask128 instead of mask256 + * here. */ + const __int64 mask64 = _mm_extract_epi64(mask128, 0); +#else + const uint32_t mask32 = string_as_uint32(mask); + const uint64_t mask64 = (uint64_t)mask32 << 32 | (uint64_t)mask32; +#endif + do { + uint64_t v = string_as_uint64(msg); + v ^= (uint64_t)mask64; + msg = mempcpy(msg, &v, sizeof(v)); + msg_len -= 8; + } while (msg_len >= 8); } } - const size_t len32 = (size_t)((msg_end - msg) / 4); - for (size_t i = 0; i < len32; i++) { - uint32_t v = string_as_uint32(msg); - v ^= (uint32_t)mask32; - msg = mempcpy(msg, &v, sizeof(v)); + if (msg_len >= 4) { + const uint32_t mask32 = string_as_uint32(mask); + do { + uint32_t v = string_as_uint32(msg); + v ^= (uint32_t)mask32; + msg = mempcpy(msg, &v, sizeof(v)); + msg_len -= 4; + } while (msg_len >= 4); } - switch (msg_end - msg) { + switch (msg_len) { case 3: msg[2] ^= mask[2]; /* fallthrough */ case 2: msg[1] ^= mask[1]; /* fallthrough */ case 1: msg[0] ^= mask[0]; + break; + default: + __builtin_unreachable(); } +#undef MASK32_SET } static void send_websocket_pong(struct lwan_request *request, uint16_t header)