Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSE2 vectorization for bitset::to_string #3960

Merged
merged 21 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions stl/inc/bitset
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ _STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

#if _USE_STD_VECTOR_ALGORITHMS
// Out-of-line helpers used by bitset::to_string(); they are defined in stl/src/vector_algorithms.cpp.
// Each writes _Size_bits characters to _Dest, one per bit of the array at _Src: _Dest[0] receives the
// highest-indexed bit, set bits are written as _Elem1 and clear bits as _Elem0.
extern "C" {
// Variant for 1-byte character types.
__declspec(noalias) void __stdcall __std_bitset_to_string_1(
char* _Dest, const void* _Src, size_t _Size_bits, char _Elem0, char _Elem1) noexcept;
// Variant for 2-byte character types.
__declspec(noalias) void __stdcall __std_bitset_to_string_2(
wchar_t* _Dest, const void* _Src, size_t _Size_bits, wchar_t _Elem0, wchar_t _Elem1) noexcept;
} // extern "C"
#endif // _USE_STD_VECTOR_ALGORITHMS

_STD_BEGIN
_EXPORT_STD template <size_t _Bits>
class bitset { // store fixed-length sequence of Boolean elements
Expand Down Expand Up @@ -348,6 +357,23 @@ public:
// convert bitset to string
basic_string<_Elem, _Tr, _Alloc> _Str;
_Str._Resize_and_overwrite(_Bits, [this, _Elem0, _Elem1](_Elem* _Buf, size_t _Len) {
#if _USE_STD_VECTOR_ALGORITHMS
constexpr size_t _Bitset_vector_threshold = 32;
if constexpr (_Bits >= _Bitset_vector_threshold && is_integral_v<_Elem> && sizeof(_Elem) <= 2) {
if (!_Is_constant_evaluated()) {
if constexpr (sizeof(_Elem) == 1) {
__std_bitset_to_string_1(reinterpret_cast<char*>(_Buf), _Array, _Len, static_cast<char>(_Elem0),
static_cast<char>(_Elem1));
} else {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Elem) == 2);
__std_bitset_to_string_2(reinterpret_cast<wchar_t*>(_Buf), _Array, _Len,
static_cast<wchar_t>(_Elem0), static_cast<wchar_t>(_Elem1));
}
return _Len;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (size_t _Pos = 0; _Pos < _Len; ++_Pos) {
_Buf[_Pos] = _Subscript(_Len - 1 - _Pos) ? _Elem1 : _Elem0;
}
Expand Down
110 changes: 110 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#if defined(_M_IX86) || defined(_M_X64) // NB: includes _M_ARM64EC
#include <cstdint>
#include <cstring>
#ifndef _M_ARM64EC
#include <intrin.h>
#include <isa_availability.h>
Expand Down Expand Up @@ -1563,5 +1564,114 @@ __declspec(noalias) size_t
return __std_count_trivial_impl<_Find_traits_8>(_First, _Last, _Val);
}

} // extern "C"

#ifndef _M_ARM64EC
namespace {
// Converts the 16 bits of _Val into 16 one-byte characters, most significant bit in byte 0.
// _Px0 must hold (elem0 ^ elem1) in every byte and _Px1 must hold elem1 in every byte, so that
// (is_clear & _Px0) ^ _Px1 yields elem0 for a clear bit and elem1 for a set bit.
__m128i __forceinline _Bitset_to_string_1_step(const uint16_t _Val, const __m128i _Px0, const __m128i _Px1) {
    // Replicate each source byte four times: dword 0 = low byte x4, dword 1 = high byte x4.
    const __m128i _Raw     = _mm_cvtsi32_si128(_Val);
    const __m128i _Doubled = _mm_unpacklo_epi8(_Raw, _Raw);
    const __m128i _Quads   = _mm_unpacklo_epi8(_Doubled, _Doubled);

    // Bytes 0-7 now carry the high source byte and bytes 8-15 the low source byte.
    const __m128i _Spread = _mm_shuffle_epi32(_Quads, _MM_SHUFFLE(0, 0, 1, 1));

    // Within each 8-byte half, byte 0 tests bit 7 and byte 7 tests bit 0.
    const __m128i _Bit_select = _mm_set1_epi64x(0x0102040810204080);
    const __m128i _Is_clear   = _mm_cmpeq_epi8(_mm_and_si128(_Spread, _Bit_select), _mm_setzero_si128());

    return _mm_xor_si128(_mm_and_si128(_Is_clear, _Px0), _Px1);
}

// Converts the 8 bits of _Val into 8 two-byte characters, most significant bit in element 0.
// _Px0 must hold (elem0 ^ elem1) in every 16-bit lane and _Px1 must hold elem1 in every lane, so that
// (is_clear & _Px0) ^ _Px1 yields elem0 for a clear bit and elem1 for a set bit.
__m128i __forceinline _Bitset_to_string_2_step(const uint8_t _Val, const __m128i _Px0, const __m128i _Px1) {
    // Lane 0 tests bit 7 (0x0080) and lane 7 tests bit 0 (0x0001).
    const __m128i _Bit_select = _mm_set_epi64x(0x0001000200040008, 0x0010002000400080);
    const __m128i _Selected   = _mm_and_si128(_mm_set1_epi16(_Val), _Bit_select);
    const __m128i _Is_clear   = _mm_cmpeq_epi16(_Selected, _mm_setzero_si128());
    return _mm_xor_si128(_mm_and_si128(_Is_clear, _Px0), _Px1);
}
} // unnamed namespace
#endif // !defined(_M_ARM64EC)

extern "C" {

// Writes the low _Size_bits bits of the bit array at _Src into _Dest as one-byte characters.
// _Dest[0] receives the highest-indexed bit and _Dest[_Size_bits - 1] receives bit 0; set bits are
// written as _Elem1 and clear bits as _Elem0. Exactly _Size_bits characters are written.
__declspec(noalias) void __stdcall __std_bitset_to_string_1(
    char* const _Dest, const void* _Src, size_t _Size_bits, const char _Elem0, const char _Elem1) noexcept {
#ifndef _M_ARM64EC
    if (_Use_sse2()) {
        // Constants consumed by the step helper: (is_clear & _Px0) ^ _Px1 selects _Elem0 or _Elem1.
        const __m128i _Px0 = _mm_set1_epi8(_Elem0 ^ _Elem1);
        const __m128i _Px1 = _mm_set1_epi8(_Elem1);
        if (_Size_bits >= 16) {
            // Fill the destination from its end towards its start, 16 characters per iteration,
            // while walking the source forwards 2 bytes at a time.
            char* _Pos = _Dest + _Size_bits;
            _Size_bits &= 0xF; // bits left for the partial group handled below
            char* const _Stop_at = _Dest + _Size_bits;
            do {
                uint16_t _Val;
                memcpy(&_Val, _Src, 2);
                const __m128i _Elems = _Bitset_to_string_1_step(_Val, _Px0, _Px1);
                _Pos -= 16;
                _mm_storeu_si128(reinterpret_cast<__m128i*>(_Pos), _Elems);
                _Advance_bytes(_Src, 2); // helper defined elsewhere; presumably moves _Src forward by 2 bytes
            } while (_Pos != _Stop_at);
        }

        if (_Size_bits > 0) {
            __assume(_Size_bits < 16);
            // Read only as many source bytes as are needed to hold the remaining bits.
            uint16_t _Val;
            if (_Size_bits > 8) {
                memcpy(&_Val, _Src, 2);
            } else {
                _Val = *reinterpret_cast<const uint8_t*>(_Src);
            }
            // Expand into a scratch buffer and copy just its last _Size_bits characters,
            // so that nothing is stored outside [_Dest, _Dest + _Size_bits).
            const __m128i _Elems = _Bitset_to_string_1_step(_Val, _Px0, _Px1);
            char _Tmp[16];
            _mm_storeu_si128(reinterpret_cast<__m128i*>(_Tmp), _Elems);
            const char* const _Tmpd = _Tmp + (16 - _Size_bits);
            for (size_t _Ix = 0; _Ix < _Size_bits; ++_Ix) {
                _Dest[_Ix] = _Tmpd[_Ix];
            }
        }

        return; // every character has been written; do not run the scalar fallback as well
    }
#endif // !defined(_M_ARM64EC)
    // Scalar fallback: one character per bit.
    const auto _Arr = reinterpret_cast<const uint8_t*>(_Src);
    for (size_t _Ix = 0; _Ix < _Size_bits; ++_Ix) {
        _Dest[_Size_bits - 1 - _Ix] = ((_Arr[_Ix >> 3] >> (_Ix & 7)) & 1) != 0 ? _Elem1 : _Elem0;
    }
}

// Writes the low _Size_bits bits of the bit array at _Src into _Dest as two-byte characters.
// _Dest[0] receives the highest-indexed bit and _Dest[_Size_bits - 1] receives bit 0; set bits are
// written as _Elem1 and clear bits as _Elem0. Exactly _Size_bits characters are written.
__declspec(noalias) void __stdcall __std_bitset_to_string_2(
    wchar_t* const _Dest, const void* _Src, size_t _Size_bits, const wchar_t _Elem0, const wchar_t _Elem1) noexcept {
#ifndef _M_ARM64EC
    if (_Use_sse2()) {
        // Constants consumed by the step helper: (is_clear & _Px0) ^ _Px1 selects _Elem0 or _Elem1.
        const __m128i _Px0 = _mm_set1_epi16(_Elem0 ^ _Elem1);
        const __m128i _Px1 = _mm_set1_epi16(_Elem1);
        if (_Size_bits >= 8) {
            // Fill the destination from its end towards its start, 8 characters per iteration,
            // while walking the source forwards 1 byte at a time.
            wchar_t* _Pos = _Dest + _Size_bits;
            _Size_bits &= 0x7; // bits left for the partial group handled below
            wchar_t* const _Stop_at = _Dest + _Size_bits;
            do {
                const uint8_t _Val   = *reinterpret_cast<const uint8_t*>(_Src);
                const __m128i _Elems = _Bitset_to_string_2_step(_Val, _Px0, _Px1);
                _Pos -= 8;
                _mm_storeu_si128(reinterpret_cast<__m128i*>(_Pos), _Elems);
                _Advance_bytes(_Src, 1); // helper defined elsewhere; presumably moves _Src forward by 1 byte
            } while (_Pos != _Stop_at);
        }

        if (_Size_bits > 0) {
            __assume(_Size_bits < 8);
            // Expand into a scratch buffer and copy just its last _Size_bits characters,
            // so that nothing is stored outside [_Dest, _Dest + _Size_bits).
            const uint8_t _Val   = *reinterpret_cast<const uint8_t*>(_Src);
            const __m128i _Elems = _Bitset_to_string_2_step(_Val, _Px0, _Px1);
            wchar_t _Tmp[8];
            _mm_storeu_si128(reinterpret_cast<__m128i*>(_Tmp), _Elems);
            const wchar_t* const _Tmpd = _Tmp + (8 - _Size_bits);
            for (size_t _Ix = 0; _Ix < _Size_bits; ++_Ix) {
                _Dest[_Ix] = _Tmpd[_Ix];
            }
        }

        return; // every character has been written; do not run the scalar fallback as well
    }
#endif // !defined(_M_ARM64EC)
    // Scalar fallback: one character per bit.
    const auto _Arr = reinterpret_cast<const uint8_t*>(_Src);
    for (size_t _Ix = 0; _Ix < _Size_bits; ++_Ix) {
        _Dest[_Size_bits - 1 - _Ix] = ((_Arr[_Ix >> 3] >> (_Ix & 7)) & 1) != 0 ? _Elem1 : _Elem0;
    }
}

} // extern "C"
#endif // defined(_M_IX86) || defined(_M_X64)
73 changes: 73 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
Expand All @@ -12,6 +13,7 @@
#include <limits>
#include <list>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

Expand Down Expand Up @@ -443,6 +445,73 @@ void test_one_container() {
test_two_containers<Container, list<int>>();
}

// Exercises bitset::to_string() for several widths (0, 15, 32, 45, 64 and 75 bits), for narrow, wide,
// char8_t (when available), char16_t and char32_t strings, and for custom zero/one characters.
// Finishes by round-tripping a 2048-bit pattern built from `gen` through bitset and back to a string.
void test_bitset(mt19937_64& gen) {
    // Narrow strings with the default '0'/'1' characters.
    assert(bitset<0>(0x0ULL).to_string() == "");
    assert(bitset<0>(0xFEDCBA9876543210ULL).to_string() == "");
    assert(bitset<15>(0x6789ULL).to_string() == "110011110001001");
    assert(bitset<15>(0xFEDCBA9876543210ULL).to_string() == "011001000010000");
    assert(bitset<32>(0xABCD1234ULL).to_string() == "10101011110011010001001000110100");
    assert(bitset<32>(0xFEDCBA9876543210ULL).to_string() == "01110110010101000011001000010000");
    assert(bitset<45>(0x1701D1729FFFULL).to_string() == "101110000000111010001011100101001111111111111");
    assert(bitset<45>(0xFEDCBA9876543210ULL).to_string() == "110101001100001110110010101000011001000010000");
    assert(bitset<64>(0xFEDCBA9876543210ULL).to_string()
           == "1111111011011100101110101001100001110110010101000011001000010000");
    assert(bitset<75>(0xFEDCBA9876543210ULL).to_string()
           == "000000000001111111011011100101110101001100001110110010101000011001000010000");

    // The same values as wide strings.
    assert(bitset<0>(0x0ULL).to_string<wchar_t>() == L"");
    assert(bitset<0>(0xFEDCBA9876543210ULL).to_string<wchar_t>() == L"");
    assert(bitset<15>(0x6789ULL).to_string<wchar_t>() == L"110011110001001");
    assert(bitset<15>(0xFEDCBA9876543210ULL).to_string<wchar_t>() == L"011001000010000");
    assert(bitset<32>(0xABCD1234ULL).to_string<wchar_t>() == L"10101011110011010001001000110100");
    assert(bitset<32>(0xFEDCBA9876543210ULL).to_string<wchar_t>() == L"01110110010101000011001000010000");
    assert(bitset<45>(0x1701D1729FFFULL).to_string<wchar_t>() == L"101110000000111010001011100101001111111111111");
    assert(bitset<45>(0xFEDCBA9876543210ULL).to_string<wchar_t>() == L"110101001100001110110010101000011001000010000");
    assert(bitset<64>(0xFEDCBA9876543210ULL).to_string<wchar_t>()
           == L"1111111011011100101110101001100001110110010101000011001000010000");
    assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<wchar_t>()
           == L"000000000001111111011011100101110101001100001110110010101000011001000010000");

    // Custom characters for clear ('o') and set ('x') bits.
    assert(bitset<64>(0xFEDCBA9876543210ULL).to_string('o', 'x')
           == "xxxxxxxoxxoxxxooxoxxxoxoxooxxooooxxxoxxooxoxoxooooxxooxooooxoooo");
    assert(bitset<64>(0xFEDCBA9876543210ULL).to_string<wchar_t>(L'o', L'x')
           == L"xxxxxxxoxxoxxxooxoxxxoxoxooxxooooxxxoxxooxoxoxooooxxooxooooxoooo");

    // Other character types.
#ifdef __cpp_lib_char8_t
    assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<char8_t>()
           == u8"000000000001111111011011100101110101001100001110110010101000011001000010000");
#endif // __cpp_lib_char8_t
    assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<char16_t>()
           == u"000000000001111111011011100101110101001100001110110010101000011001000010000");
    assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<char32_t>()
           == U"000000000001111111011011100101110101001100001110110010101000011001000010000"); // not vectorized

    {
        // Round trip: build a random N-character '0'/'1' string, construct a bitset from it,
        // and check that to_string() reproduces it in both narrow and wide form.
        constexpr size_t N = 2048;

        string str;
        wstring wstr;
        str.reserve(N);
        wstr.reserve(N);

        while (str.size() != N) {
            uint64_t random_value = gen();

            // Append one character per bit of the 64-bit random value, lowest bit first.
            for (int bits = 0; bits < 64; ++bits) {
                const auto character = '0' + (random_value & 1);
                str.push_back(static_cast<char>(character));
                wstr.push_back(static_cast<wchar_t>(character));
                random_value >>= 1;
            }
        }

        const bitset<N> b(str);

        assert(b.to_string() == str);
        assert(b.to_string<wchar_t>() == wstr);
    }
}

void test_various_containers() {
test_one_container<vector<int>>(); // contiguous, vectorizable
test_one_container<deque<int>>(); // random-access, not vectorizable
Expand Down Expand Up @@ -524,20 +593,24 @@ int main() {

test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);
#ifndef _M_CEE_PURE
#if defined(_M_IX86) || defined(_M_X64)
disable_instructions(__ISA_AVAILABLE_AVX2);
test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);

disable_instructions(__ISA_AVAILABLE_SSE42);
test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);
#endif // defined(_M_IX86) || defined(_M_X64)
#if defined(_M_IX86)
disable_instructions(__ISA_AVAILABLE_SSE2);
test_vector_algorithms(gen);
test_various_containers();
test_bitset(gen);
#endif // defined(_M_IX86)
#endif // _M_CEE_PURE
}