diff --git a/stl/inc/limits b/stl/inc/limits index 9421c6d3088..2456f48ccf8 100644 --- a/stl/inc/limits +++ b/stl/inc/limits @@ -1051,6 +1051,11 @@ _NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept { return static_cast(_Val >> (_Digits - 8)); } +enum class _Countr_zero_assumption { + _Possibly_zero, + _Nonzero, +}; + #if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) extern "C" { extern int __isa_available; @@ -1063,16 +1068,18 @@ extern int __isa_available; #endif // __clang__ } -template +template <_Countr_zero_assumption _Assumption, class _Ty> _NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept { constexpr int _Digits = numeric_limits<_Ty>::digits; - constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); - if constexpr (_Digits <= 32) { + if constexpr (_Digits <= 16 && _Assumption == _Countr_zero_assumption::_Possibly_zero) { // Intended widening to int. This operation means that a narrow 0 will widen // to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros // of the wider type. + constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); return static_cast(_TZCNT_U32(static_cast(~_Max | _Val))); + } else if constexpr (_Digits <= 32) { + return static_cast(_TZCNT_U32(_Val)); } else { #ifdef _M_IX86 const auto _Low = static_cast(_Val); @@ -1088,19 +1095,21 @@ _NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept { } } -template +template <_Countr_zero_assumption _Assumption = _Countr_zero_assumption::_Possibly_zero, class _Ty> _NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept { constexpr int _Digits = numeric_limits<_Ty>::digits; - constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); unsigned long _Result; - if constexpr (_Digits <= 32) { + unsigned char _Bsf_return; + + if constexpr (_Digits <= 16 && _Assumption == _Countr_zero_assumption::_Possibly_zero) { // Intended widening to int. This operation means that a narrow 0 will widen // to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros // of the wider type. - if (!_BitScanForward(&_Result, static_cast(~_Max | _Val))) { - return _Digits; - } + constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); + _Bsf_return = _BitScanForward(&_Result, static_cast(~_Max | _Val)); + } else if constexpr (_Digits <= 32) { + _Bsf_return = _BitScanForward(&_Result, _Val); } else { #ifdef _M_IX86 const auto _Low = static_cast(_Val); @@ -1115,24 +1124,31 @@ _NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept { return static_cast(_Result + 32); } #else // ^^^ _M_IX86 / !_M_IX86 vvv - if (!_BitScanForward64(&_Result, _Val)) { + _Bsf_return = _BitScanForward64(&_Result, _Val); +#endif // _M_IX86 + } + + if constexpr (_Digits >= 32 && _Assumption == _Countr_zero_assumption::_Possibly_zero) { + if (!_Bsf_return) { return _Digits; } -#endif // _M_IX86 + } else { + (void) _Bsf_return; } + return static_cast(_Result); } -template +template <_Countr_zero_assumption _Assumption, class _Ty> _NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept { #ifdef __AVX2__ - return _Countr_zero_tzcnt(_Val); + return _Countr_zero_tzcnt<_Assumption>(_Val); #else // __AVX2__ const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2; if (_Definitely_have_tzcnt) { - return _Countr_zero_tzcnt(_Val); + return _Countr_zero_tzcnt<_Assumption>(_Val); } else { - return _Countr_zero_bsf(_Val); + return _Countr_zero_bsf<_Assumption>(_Val); } #endif // __AVX2__ } @@ -1188,12 +1204,13 @@ template constexpr bool _Is_standard_unsigned_integer = _Is_any_of_v, unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>; -template , int> = 0> +template <_Countr_zero_assumption _Assumption = _Countr_zero_assumption::_Possibly_zero, // + class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> = 0> _NODISCARD constexpr int _Countr_zero(const _Ty _Val) noexcept { #if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) #if _HAS_CXX20 if (!_STD is_constant_evaluated()) { - return _Checked_x86_x64_countr_zero(_Val); + return _Checked_x86_x64_countr_zero<_Assumption>(_Val); } #endif // _HAS_CXX20 #endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) diff --git a/stl/inc/numeric b/stl/inc/numeric index 1d6c90bf094..4b8c74af988 100644 --- a/stl/inc/numeric +++ b/stl/inc/numeric @@ -560,13 +560,14 @@ _NODISCARD constexpr common_type_t<_Mt, _Nt> gcd(const _Mt _Mx, const _Nt _Nx) n return static_cast<_Common>(_Mx_magnitude); } - const auto _Mx_trailing_zeroes = static_cast(_Countr_zero(_Mx_magnitude)); + constexpr auto _Nonzero = _Countr_zero_assumption::_Nonzero; + const auto _Mx_trailing_zeroes = static_cast(_Countr_zero<_Nonzero>(_Mx_magnitude)); const auto _Common_factors_of_2 = - (_STD min) (_Mx_trailing_zeroes, static_cast(_Countr_zero(_Nx_magnitude))); + (_STD min) (_Mx_trailing_zeroes, static_cast(_Countr_zero<_Nonzero>(_Nx_magnitude))); _Nx_magnitude >>= _Common_factors_of_2; _Mx_magnitude >>= _Mx_trailing_zeroes; do { - _Nx_magnitude >>= static_cast(_Countr_zero(_Nx_magnitude)); + _Nx_magnitude >>= static_cast(_Countr_zero<_Nonzero>(_Nx_magnitude)); if (_Mx_magnitude > _Nx_magnitude) { _Common_unsigned _Temp = _Mx_magnitude; _Mx_magnitude = _Nx_magnitude; diff --git a/stl/inc/vector b/stl/inc/vector index 130a9c9544e..bc896f47e99 100644 --- a/stl/inc/vector +++ b/stl/inc/vector @@ -1814,7 +1814,8 @@ struct _Vbase_compare_three_way { #endif // ^^^ !defined(__cpp_lib_concepts) ^^^ } - const int _Bit_index = _Countr_zero(_Differing_bits); // number of least significant bits that match + constexpr auto _Nonzero = _Countr_zero_assumption::_Nonzero; + const int _Bit_index = _Countr_zero<_Nonzero>(_Differing_bits); // number of least significant bits that match _STL_INTERNAL_CHECK(_Bit_index < _VBITS); // because we return early for equality const _Vbase _Mask = _Vbase{1} << _Bit_index; // selects the least significant bit that differs