From ffd735ac1aacc9623fefc885be66fd9dce1ce310 Mon Sep 17 00:00:00 2001 From: Alex Guteniev Date: Thu, 28 Mar 2024 18:35:20 +0200 Subject: [PATCH] `mismatch` vectorization (#4495) Co-authored-by: Stephan T. Lavavej --- benchmarks/CMakeLists.txt | 1 + benchmarks/src/mismatch.cpp | 36 ++++ stl/inc/algorithm | 34 ++++ stl/inc/xutility | 35 ++++ stl/src/vector_algorithms.cpp | 99 ++++++++++ .../VSO_0000000_vector_algorithms/test.cpp | 183 ++++++++++++++++++ 6 files changed, 388 insertions(+) create mode 100644 benchmarks/src/mismatch.cpp diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index a45c736901..0f22a1f003 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -113,6 +113,7 @@ add_benchmark(find_and_count src/find_and_count.cpp) add_benchmark(find_first_of src/find_first_of.cpp) add_benchmark(locale_classic src/locale_classic.cpp) add_benchmark(minmax_element src/minmax_element.cpp) +add_benchmark(mismatch src/mismatch.cpp) add_benchmark(path_lexically_normal src/path_lexically_normal.cpp) add_benchmark(priority_queue_push_range src/priority_queue_push_range.cpp) add_benchmark(random_integer_generation src/random_integer_generation.cpp) diff --git a/benchmarks/src/mismatch.cpp b/benchmarks/src/mismatch.cpp new file mode 100644 index 0000000000..f6a3069ac2 --- /dev/null +++ b/benchmarks/src/mismatch.cpp @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include + +using namespace std; + +constexpr int64_t no_pos = -1; + +template +void bm(benchmark::State& state) { + vector a(static_cast(state.range(0)), T{'.'}); + vector b(static_cast(state.range(0)), T{'.'}); + + if (state.range(1) != no_pos) { + b.at(static_cast(state.range(1))) = 'x'; + } + + for (auto _ : state) { + benchmark::DoNotOptimize(ranges::mismatch(a, b)); + } +} + +#define COMMON_ARGS Args({8, 3})->Args({24, 22})->Args({105, -1})->Args({4021, 3056}) + +BENCHMARK(bm)->COMMON_ARGS; +BENCHMARK(bm)->COMMON_ARGS; +BENCHMARK(bm)->COMMON_ARGS; +BENCHMARK(bm)->COMMON_ARGS; + +BENCHMARK_MAIN(); diff --git a/stl/inc/algorithm b/stl/inc/algorithm index fd9882d74c..ff1f8b65e3 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -669,6 +669,23 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(_InIt1 _First1, const _InI auto _UFirst1 = _STD _Get_unwrapped(_First1); const auto _ULast1 = _STD _Get_unwrapped(_Last1); auto _UFirst2 = _STD _Get_unwrapped_n(_First2, _STD _Idl_distance<_InIt1>(_UFirst1, _ULast1)); +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Equal_memcmp_is_safe) { + if (!_STD _Is_constant_evaluated()) { + constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>); + + const size_t _Pos = _STD __std_mismatch<_Elem_size>( + _STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast(_ULast1 - _UFirst1)); + + _UFirst1 += static_cast<_Iter_diff_t<_InIt1>>(_Pos); + _UFirst2 += static_cast<_Iter_diff_t<_InIt2>>(_Pos); + + _STD _Seek_wrapped(_First2, _UFirst2); + _STD _Seek_wrapped(_First1, _UFirst1); + return {_First1, _First2}; + } + } +#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^ while (_UFirst1 != _ULast1 && _Pred(*_UFirst1, *_UFirst2)) { ++_UFirst1; ++_UFirst2; @@ -716,6 +733,23 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch( const _CT _Count2 = _ULast2 - _UFirst2; const auto _Count = static_cast<_Iter_diff_t<_InIt1>>((_STD min)(_Count1, _Count2)); _ULast1 = _UFirst1 + _Count; +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Equal_memcmp_is_safe) { + if (!_STD _Is_constant_evaluated()) { + constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>); + + const size_t _Pos = _STD __std_mismatch<_Elem_size>( + _STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast(_Count)); + + _UFirst1 += static_cast<_Iter_diff_t<_InIt1>>(_Pos); + _UFirst2 += static_cast<_Iter_diff_t<_InIt2>>(_Pos); + + _STD _Seek_wrapped(_First2, _UFirst2); + _STD _Seek_wrapped(_First1, _UFirst1); + return {_First1, _First2}; + } + } +#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^ while (_UFirst1 != _ULast1 && _Pred(*_UFirst1, *_UFirst2)) { ++_UFirst1; ++_UFirst2; diff --git a/stl/inc/xutility b/stl/inc/xutility index 260c44f211..f2e0482894 100644 --- a/stl/inc/xutility +++ b/stl/inc/xutility @@ -124,6 +124,11 @@ __declspec(noalias) int64_t __stdcall __std_max_8i(const void* _First, const voi __declspec(noalias) uint64_t __stdcall __std_max_8u(const void* _First, const void* _Last) noexcept; __declspec(noalias) float __stdcall __std_max_f(const void* _First, const void* _Last) noexcept; __declspec(noalias) double __stdcall __std_max_d(const void* _First, const void* _Last) noexcept; + +__declspec(noalias) size_t __stdcall __std_mismatch_1(const void* _First1, const void* _First2, size_t _Count) noexcept; +__declspec(noalias) size_t __stdcall __std_mismatch_2(const void* _First1, const void* _First2, size_t _Count) noexcept; +__declspec(noalias) size_t __stdcall __std_mismatch_4(const void* _First1, const void* _First2, size_t _Count) noexcept; +__declspec(noalias) size_t __stdcall __std_mismatch_8(const void* _First1, const void* _First2, size_t _Count) noexcept; } // extern "C" _STD_BEGIN @@ -292,6 +297,22 @@ auto __std_max(_Ty* const _First, _Ty* const _Last) noexcept { static_assert(_Always_false<_Ty>, "Unexpected size"); } } + +template +inline size_t // TRANSITION, GH-4496 + __std_mismatch(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + if constexpr (_Element_size == 1) { + return __std_mismatch_1(_First1, _First2, _Count); + } else if constexpr (_Element_size == 2) { + return __std_mismatch_2(_First1, _First2, _Count); + } else if constexpr (_Element_size == 4) { + return __std_mismatch_4(_First1, _First2, _Count); + } else if constexpr (_Element_size == 8) { + return __std_mismatch_8(_First1, _First2, _Count); + } else { + static_assert(_Always_false>, "Unexpected size"); + } +} _STD_END #endif // _USE_STD_VECTOR_ALGORITHMS @@ -5477,6 +5498,20 @@ namespace ranges { _NODISCARD constexpr mismatch_result<_It1, _It2> _Mismatch_n( _It1 _First1, _It2 _First2, iter_difference_t<_It1> _Count, _Pr _Pred, _Pj1 _Proj1, _Pj2 _Proj2) { _STL_INTERNAL_CHECK(_Count >= 0); +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Equal_memcmp_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity> + && is_same_v<_Pj2, identity>) { + if (!_STD is_constant_evaluated()) { + constexpr size_t _Elem_size = sizeof(iter_value_t<_It1>); + + const size_t _Pos = _STD __std_mismatch<_Elem_size>( + _STD _To_address(_First1), _STD _To_address(_First2), static_cast(_Count)); + + return {_First1 + static_cast>(_Pos), + _First2 + static_cast>(_Pos)}; + } + } +#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^ for (; _Count != 0; ++_First1, (void) ++_First2, --_Count) { if (!_STD invoke(_Pred, _STD invoke(_Proj1, *_First1), _STD invoke(_Proj2, *_First2))) { break; diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index ead6d699cf..a5a637695e 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -88,6 +88,12 @@ namespace { void _Advance_bytes(const void*& _Target, _Integral _Offset) noexcept { _Target = static_cast(_Target) + _Offset; } + + __m256i _Avx2_tail_mask_32(const size_t _Count_in_dwords) noexcept { + // _Count_in_dwords must be within [1, 7]. + static constexpr unsigned int _Tail_masks[14] = {~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0}; + return _mm256_loadu_si256(reinterpret_cast(_Tail_masks + (7 - _Count_in_dwords))); + } } // unnamed namespace extern "C" { @@ -2077,6 +2083,79 @@ namespace { return _Ptr_haystack; } + + template + __declspec(noalias) size_t __stdcall __std_mismatch_impl( + const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + size_t _Result = 0; +#ifndef _M_ARM64EC + const auto _First1_ch = static_cast(_First1); + const auto _First2_ch = static_cast(_First2); + + if (_Use_avx2()) { + const size_t _Count_bytes = _Count * sizeof(_Ty); + const size_t _Count_bytes_avx_full = _Count_bytes & ~size_t{0x1F}; + + for (; _Result != _Count_bytes_avx_full; _Result += 0x20) { + const __m256i _Elem1 = _mm256_loadu_si256(reinterpret_cast(_First1_ch + _Result)); + const __m256i _Elem2 = _mm256_loadu_si256(reinterpret_cast(_First2_ch + _Result)); + const auto _Bingo = ~static_cast(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2))); + if (_Bingo != 0) { + return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty); + } + } + + const size_t _Count_tail = _Count_bytes & size_t{0x1C}; + + if (_Count_tail != 0) { + const __m256i _Tail_mask = _Avx2_tail_mask_32(_Count_tail >> 2); + const __m256i _Elem1 = + _mm256_maskload_epi32(reinterpret_cast(_First1_ch + _Result), _Tail_mask); + const __m256i _Elem2 = + _mm256_maskload_epi32(reinterpret_cast(_First2_ch + _Result), _Tail_mask); + + const auto _Bingo = ~static_cast(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2))); + if (_Bingo != 0) { + return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty); + } + + _Result += _Count_tail; + } + + _Result /= sizeof(_Ty); + + if constexpr (sizeof(_Ty) >= 4) { + return _Result; + } + } else if (_Traits::_Sse_available()) { + const size_t _Count_bytes_sse = (_Count * sizeof(_Ty)) & ~size_t{0xF}; + + for (; _Result != _Count_bytes_sse; _Result += 0x10) { + const __m128i _Elem1 = _mm_loadu_si128(reinterpret_cast(_First1_ch + _Result)); + const __m128i _Elem2 = _mm_loadu_si128(reinterpret_cast(_First2_ch + _Result)); + const auto _Bingo = + static_cast(_mm_movemask_epi8(_Traits::_Cmp_sse(_Elem1, _Elem2))) ^ 0xFFFF; + if (_Bingo != 0) { + unsigned long _Offset; + _BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable] + return (_Result + _Offset) / sizeof(_Ty); + } + } + + _Result /= sizeof(_Ty); + } +#endif // !defined(_M_ARM64EC) + const auto _First1_el = static_cast(_First1); + const auto _First2_el = static_cast(_First2); + + for (; _Result != _Count; ++_Result) { + if (_First1_el[_Result] != _First2_el[_Result]) { + break; + } + } + + return _Result; + } } // unnamed namespace extern "C" { @@ -2172,6 +2251,26 @@ const void* __stdcall __std_find_first_of_trivial_2( return __std_find_first_of_trivial_impl(_First1, _Last1, _First2, _Last2); } +__declspec(noalias) size_t + __stdcall __std_mismatch_1(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + return __std_mismatch_impl<_Find_traits_1, uint8_t>(_First1, _First2, _Count); +} + +__declspec(noalias) size_t + __stdcall __std_mismatch_2(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + return __std_mismatch_impl<_Find_traits_2, uint16_t>(_First1, _First2, _Count); +} + +__declspec(noalias) size_t + __stdcall __std_mismatch_4(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + return __std_mismatch_impl<_Find_traits_4, uint32_t>(_First1, _First2, _Count); +} + +__declspec(noalias) size_t + __stdcall __std_mismatch_8(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { + return __std_mismatch_impl<_Find_traits_8, uint64_t>(_First1, _First2, _Count); +} + } // extern "C" #ifndef _M_ARM64EC diff --git a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp index af4f710cdd..e263e59d62 100644 --- a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp +++ b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -376,6 +377,165 @@ void test_min_max_element_special_cases() { == v.begin() + 2 * block_size_in_elements + last_vector_first_elem + 9); } +template +auto last_known_good_mismatch(FwdIt first1, FwdIt last1, FwdIt first2, FwdIt last2) { + for (; first1 != last1 && first2 != last2; ++first1, ++first2) { + if (*first1 != *first2) { + break; + } + } + + return make_pair(first1, first2); +} + +template +void test_case_mismatch(const vector& a, const vector& b) { + auto expected = last_known_good_mismatch(a.begin(), a.end(), b.begin(), b.end()); + auto actual = mismatch(a.begin(), a.end(), b.begin(), b.end()); + assert(expected == actual); +#if _HAS_CXX20 + auto ranges_actual = ranges::mismatch(a, b); + assert(get<0>(expected) == ranges_actual.in1); + assert(get<1>(expected) == ranges_actual.in2); +#endif // _HAS_CXX20 +} + +template +void test_mismatch(mt19937_64& gen) { + constexpr size_t shrinkCount = 4; + constexpr size_t mismatchCount = 30; + using TD = conditional_t; + uniform_int_distribution dis('a', 'z'); + vector input_a; + vector input_b; + input_a.reserve(dataCount); + input_b.reserve(dataCount); + + for (;;) { + // equal + test_case_mismatch(input_a, input_b); + + // different sizes + for (size_t i = 0; i != shrinkCount && !input_b.empty(); ++i) { + test_case_mismatch(input_a, input_b); + test_case_mismatch(input_b, input_a); + input_b.pop_back(); + } + + // actual mismatch (or maybe not, depending on random) + if (!input_b.empty()) { + uniform_int_distribution mismatch_dis(0, input_a.size() - 1); + + for (size_t attempts = 0; attempts < mismatchCount; ++attempts) { + const size_t possible_mismatch_pos = mismatch_dis(gen); + input_a[possible_mismatch_pos] = static_cast(dis(gen)); + test_case_mismatch(input_a, input_b); + test_case_mismatch(input_b, input_a); + } + } + + if (input_a.size() == dataCount) { + break; + } + + input_a.push_back(static_cast(dis(gen))); + input_b = input_a; + } +} + +template +void test_mismatch_containers() { + C1 a{'m', 'e', 'o', 'w', ' ', 'C', 'A', 'T', 'S'}; + C2 b{'m', 'e', 'o', 'w', ' ', 'K', 'I', 'T', 'T', 'E', 'N', 'S'}; + const auto result_4 = mismatch(a.begin(), a.end(), b.begin(), b.end()); + const auto result_3 = mismatch(a.begin(), a.end(), b.begin()); + assert(get<0>(result_4) == a.begin() + 5); + assert(get<1>(result_4) == b.begin() + 5); + assert(get<0>(result_3) == a.begin() + 5); + assert(get<1>(result_3) == b.begin() + 5); +#if _HAS_CXX20 + const auto result_r = ranges::mismatch(a, b); + assert(result_r.in1 == a.begin() + 5); + assert(result_r.in2 == b.begin() + 5); +#endif // _HAS_CXX20 +} + +namespace test_mismatch_sizes_and_alignments { + constexpr size_t range = 33; + constexpr size_t alignment = 32; + +#pragma pack(push, 1) + template + struct with_pad { + char p[PadSize]; + T v[Size]; + }; +#pragma pack(pop) + + template + char stack_array_various_alignments_impl() { + with_pad a = {}; + with_pad b = {}; + assert(mismatch(begin(a.v), end(a.v), begin(b.v), end(b.v)) == make_pair(end(a.v), end(b.v))); + return 0; + } + + template + void stack_array_various_alignments(index_sequence) { + char ignored[] = {stack_array_various_alignments_impl()...}; + (void) ignored; + } + + template + char stack_array_impl() { + T a[Size + 1] = {}; + T b[Size + 1] = {}; + assert(mismatch(begin(a), end(a), begin(b), end(b)) == make_pair(end(a), end(b))); + stack_array_various_alignments(make_index_sequence{}); + return 0; + } + + template + void stack_array(index_sequence) { + char ignored[] = {stack_array_impl()...}; + (void) ignored; + } + + template + void test() { + // stack with different sizes and alignments. ASan would catch out-of-range reads + stack_array(make_index_sequence{}); + + // vector with different sizes. ASan vector annotations would catch out-of-range reads + for (size_t i = 0; i != range; ++i) { + vector a(i, 0); + vector b(i, 0); + assert(mismatch(begin(a), end(a), begin(b), end(b)) == make_pair(end(a), end(b))); + } + + // heap with different sizes. ASan would catch out-of-range reads + for (size_t i = 0; i != range; ++i) { + T* a = static_cast(calloc(i, sizeof(T))); + T* b = static_cast(calloc(i, sizeof(T))); + assert(mismatch(a, a + i, b, b + i) == make_pair(a + i, b + i)); + free(a); + free(b); + } + + // subarray from stack array. We would have wrong results if we run out of the range. + T a[range + 1] = {}; + T b[range + 1] = {}; + for (size_t i = 0; i != range; ++i) { + a[i + 1] = 1; + // whole range mismatch finds mismatch after past-the-end of the subarray + assert(mismatch(a, a + range + 1, b, b + range + 1) == make_pair(a + i + 1, b + i + 1)); + // limited range mismatch gets to past-the-end of the subarray + assert(mismatch(a, a + i, b, b + i) == make_pair(a + i, b + i)); + a[i + 1] = 0; + } + } +} // namespace test_mismatch_sizes_and_alignments + template void last_known_good_reverse(BidIt first, BidIt last) { for (; first != last && first != --last; ++first) { @@ -545,6 +705,29 @@ void test_vector_algorithms(mt19937_64& gen) { test_case_min_max_element( vector{-6604286336755016904, -4365366089374418225, 6104371530830675888, -8582621853879131834}); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + test_mismatch(gen); + + test_mismatch_containers, vector>(); + test_mismatch_containers, vector>(); + test_mismatch_containers, vector>(); + test_mismatch_containers, const vector>(); + test_mismatch_containers, const vector>(); + test_mismatch_containers, vector>(); + test_mismatch_containers, vector>(); + + test_mismatch_sizes_and_alignments::test(); + test_mismatch_sizes_and_alignments::test(); + test_mismatch_sizes_and_alignments::test(); + test_mismatch_sizes_and_alignments::test(); + test_reverse(gen); test_reverse(gen); test_reverse(gen);