Skip to content

Commit

Permalink
mismatch vectorization (#4495)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Mar 28, 2024
1 parent 8e2d724 commit ffd735a
Show file tree
Hide file tree
Showing 6 changed files with 388 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ add_benchmark(find_and_count src/find_and_count.cpp)
add_benchmark(find_first_of src/find_first_of.cpp)
add_benchmark(locale_classic src/locale_classic.cpp)
add_benchmark(minmax_element src/minmax_element.cpp)
add_benchmark(mismatch src/mismatch.cpp)
add_benchmark(path_lexically_normal src/path_lexically_normal.cpp)
add_benchmark(priority_queue_push_range src/priority_queue_push_range.cpp)
add_benchmark(random_integer_generation src/random_integer_generation.cpp)
Expand Down
36 changes: 36 additions & 0 deletions benchmarks/src/mismatch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <benchmark/benchmark.h>
#include <cstddef>
#include <cstdint>
#include <ranges>
#include <vector>

using namespace std;

constexpr int64_t no_pos = -1;

template <class T>
void bm(benchmark::State& state) {
vector<T> a(static_cast<size_t>(state.range(0)), T{'.'});
vector<T> b(static_cast<size_t>(state.range(0)), T{'.'});

if (state.range(1) != no_pos) {
b.at(static_cast<size_t>(state.range(1))) = 'x';
}

for (auto _ : state) {
benchmark::DoNotOptimize(ranges::mismatch(a, b));
}
}

#define COMMON_ARGS Args({8, 3})->Args({24, 22})->Args({105, -1})->Args({4021, 3056})

BENCHMARK(bm<uint8_t>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t>)->COMMON_ARGS;

BENCHMARK_MAIN();
34 changes: 34 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,23 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(_InIt1 _First1, const _InI
auto _UFirst1 = _STD _Get_unwrapped(_First1);
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
auto _UFirst2 = _STD _Get_unwrapped_n(_First2, _STD _Idl_distance<_InIt1>(_UFirst1, _ULast1));
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);

const size_t _Pos = _STD __std_mismatch<_Elem_size>(
_STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast<size_t>(_ULast1 - _UFirst1));

_UFirst1 += static_cast<_Iter_diff_t<_InIt1>>(_Pos);
_UFirst2 += static_cast<_Iter_diff_t<_InIt2>>(_Pos);

_STD _Seek_wrapped(_First2, _UFirst2);
_STD _Seek_wrapped(_First1, _UFirst1);
return {_First1, _First2};
}
}
#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^
while (_UFirst1 != _ULast1 && _Pred(*_UFirst1, *_UFirst2)) {
++_UFirst1;
++_UFirst2;
Expand Down Expand Up @@ -716,6 +733,23 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(
const _CT _Count2 = _ULast2 - _UFirst2;
const auto _Count = static_cast<_Iter_diff_t<_InIt1>>((_STD min)(_Count1, _Count2));
_ULast1 = _UFirst1 + _Count;
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);

const size_t _Pos = _STD __std_mismatch<_Elem_size>(
_STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count));

_UFirst1 += static_cast<_Iter_diff_t<_InIt1>>(_Pos);
_UFirst2 += static_cast<_Iter_diff_t<_InIt2>>(_Pos);

_STD _Seek_wrapped(_First2, _UFirst2);
_STD _Seek_wrapped(_First1, _UFirst1);
return {_First1, _First2};
}
}
#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^
while (_UFirst1 != _ULast1 && _Pred(*_UFirst1, *_UFirst2)) {
++_UFirst1;
++_UFirst2;
Expand Down
35 changes: 35 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ __declspec(noalias) int64_t __stdcall __std_max_8i(const void* _First, const voi
__declspec(noalias) uint64_t __stdcall __std_max_8u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) float __stdcall __std_max_f(const void* _First, const void* _Last) noexcept;
__declspec(noalias) double __stdcall __std_max_d(const void* _First, const void* _Last) noexcept;

__declspec(noalias) size_t __stdcall __std_mismatch_1(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_2(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_4(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_8(const void* _First1, const void* _First2, size_t _Count) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -292,6 +297,22 @@ auto __std_max(_Ty* const _First, _Ty* const _Last) noexcept {
static_assert(_Always_false<_Ty>, "Unexpected size");
}
}

template <size_t _Element_size>
inline size_t // TRANSITION, GH-4496
__std_mismatch(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
if constexpr (_Element_size == 1) {
return __std_mismatch_1(_First1, _First2, _Count);
} else if constexpr (_Element_size == 2) {
return __std_mismatch_2(_First1, _First2, _Count);
} else if constexpr (_Element_size == 4) {
return __std_mismatch_4(_First1, _First2, _Count);
} else if constexpr (_Element_size == 8) {
return __std_mismatch_8(_First1, _First2, _Count);
} else {
static_assert(_Always_false<integral_constant<size_t, _Element_size>>, "Unexpected size");
}
}
_STD_END

#endif // _USE_STD_VECTOR_ALGORITHMS
Expand Down Expand Up @@ -5477,6 +5498,20 @@ namespace ranges {
_NODISCARD constexpr mismatch_result<_It1, _It2> _Mismatch_n(
_It1 _First1, _It2 _First2, iter_difference_t<_It1> _Count, _Pr _Pred, _Pj1 _Proj1, _Pj2 _Proj2) {
_STL_INTERNAL_CHECK(_Count >= 0);
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Equal_memcmp_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
constexpr size_t _Elem_size = sizeof(iter_value_t<_It1>);

const size_t _Pos = _STD __std_mismatch<_Elem_size>(
_STD _To_address(_First1), _STD _To_address(_First2), static_cast<size_t>(_Count));

return {_First1 + static_cast<iter_difference_t<_It1>>(_Pos),
_First2 + static_cast<iter_difference_t<_It2>>(_Pos)};
}
}
#endif // ^^^ _USE_STD_VECTOR_ALGORITHMS ^^^
for (; _Count != 0; ++_First1, (void) ++_First2, --_Count) {
if (!_STD invoke(_Pred, _STD invoke(_Proj1, *_First1), _STD invoke(_Proj2, *_First2))) {
break;
Expand Down
99 changes: 99 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ namespace {
void _Advance_bytes(const void*& _Target, _Integral _Offset) noexcept {
_Target = static_cast<const unsigned char*>(_Target) + _Offset;
}

__m256i _Avx2_tail_mask_32(const size_t _Count_in_dwords) noexcept {
// _Count_in_dwords must be within [1, 7].
static constexpr unsigned int _Tail_masks[14] = {~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0};
return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_Tail_masks + (7 - _Count_in_dwords)));
}
} // unnamed namespace

extern "C" {
Expand Down Expand Up @@ -2077,6 +2083,79 @@ namespace {
return _Ptr_haystack;
}


template <class _Traits, class _Ty>
__declspec(noalias) size_t __stdcall __std_mismatch_impl(
const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
size_t _Result = 0;
#ifndef _M_ARM64EC
const auto _First1_ch = static_cast<const char*>(_First1);
const auto _First2_ch = static_cast<const char*>(_First2);

if (_Use_avx2()) {
const size_t _Count_bytes = _Count * sizeof(_Ty);
const size_t _Count_bytes_avx_full = _Count_bytes & ~size_t{0x1F};

for (; _Result != _Count_bytes_avx_full; _Result += 0x20) {
const __m256i _Elem1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First1_ch + _Result));
const __m256i _Elem2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First2_ch + _Result));
const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2)));
if (_Bingo != 0) {
return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty);
}
}

const size_t _Count_tail = _Count_bytes & size_t{0x1C};

if (_Count_tail != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Count_tail >> 2);
const __m256i _Elem1 =
_mm256_maskload_epi32(reinterpret_cast<const int*>(_First1_ch + _Result), _Tail_mask);
const __m256i _Elem2 =
_mm256_maskload_epi32(reinterpret_cast<const int*>(_First2_ch + _Result), _Tail_mask);

const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2)));
if (_Bingo != 0) {
return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty);
}

_Result += _Count_tail;
}

_Result /= sizeof(_Ty);

if constexpr (sizeof(_Ty) >= 4) {
return _Result;
}
} else if (_Traits::_Sse_available()) {
const size_t _Count_bytes_sse = (_Count * sizeof(_Ty)) & ~size_t{0xF};

for (; _Result != _Count_bytes_sse; _Result += 0x10) {
const __m128i _Elem1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First1_ch + _Result));
const __m128i _Elem2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First2_ch + _Result));
const auto _Bingo =
static_cast<unsigned int>(_mm_movemask_epi8(_Traits::_Cmp_sse(_Elem1, _Elem2))) ^ 0xFFFF;
if (_Bingo != 0) {
unsigned long _Offset;
_BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
return (_Result + _Offset) / sizeof(_Ty);
}
}

_Result /= sizeof(_Ty);
}
#endif // !defined(_M_ARM64EC)
const auto _First1_el = static_cast<const _Ty*>(_First1);
const auto _First2_el = static_cast<const _Ty*>(_First2);

for (; _Result != _Count; ++_Result) {
if (_First1_el[_Result] != _First2_el[_Result]) {
break;
}
}

return _Result;
}
} // unnamed namespace

extern "C" {
Expand Down Expand Up @@ -2172,6 +2251,26 @@ const void* __stdcall __std_find_first_of_trivial_2(
return __std_find_first_of_trivial_impl<uint16_t>(_First1, _Last1, _First2, _Last2);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_1(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_1, uint8_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_2(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_2, uint16_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_4(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_4, uint32_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_8(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_8, uint64_t>(_First1, _First2, _Count);
}

} // extern "C"

#ifndef _M_ARM64EC
Expand Down
Loading

0 comments on commit ffd735a

Please sign in to comment.