Skip to content
51 changes: 34 additions & 17 deletions stl/inc/limits
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,11 @@ _NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept {
return static_cast<int>(_Val >> (_Digits - 8));
}

enum class _Countr_zero_assumption {
_Possibly_zero,
_Nonzero,
};

#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
extern "C" {
extern int __isa_available;
Expand All @@ -1063,16 +1068,18 @@ extern int __isa_available;
#endif // __clang__
}

template <class _Ty>
template <_Countr_zero_assumption _Assumption, class _Ty>
_NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept {
constexpr int _Digits = numeric_limits<_Ty>::digits;
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();

if constexpr (_Digits <= 32) {
if constexpr (_Digits <= 16 && _Assumption == _Countr_zero_assumption::_Possibly_zero) {
// Intended widening to int. This operation means that a narrow 0 will widen
// to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros
// of the wider type.
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();
return static_cast<int>(_TZCNT_U32(static_cast<unsigned int>(~_Max | _Val)));
} else if constexpr (_Digits <= 32) {
return static_cast<int>(_TZCNT_U32(_Val));
} else {
#ifdef _M_IX86
const auto _Low = static_cast<unsigned int>(_Val);
Expand All @@ -1088,19 +1095,21 @@ _NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept {
}
}

template <class _Ty>
template <_Countr_zero_assumption _Assumption = _Countr_zero_assumption::_Possibly_zero, class _Ty>
_NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept {
constexpr int _Digits = numeric_limits<_Ty>::digits;
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();

unsigned long _Result;
if constexpr (_Digits <= 32) {
unsigned char _Bsf_return;

if constexpr (_Digits <= 16 && _Assumption == _Countr_zero_assumption::_Possibly_zero) {
// Intended widening to int. This operation means that a narrow 0 will widen
// to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros
// of the wider type.
if (!_BitScanForward(&_Result, static_cast<unsigned int>(~_Max | _Val))) {
return _Digits;
}
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();
_Bsf_return = _BitScanForward(&_Result, static_cast<unsigned int>(~_Max | _Val));
} else if constexpr (_Digits <= 32) {
_Bsf_return = _BitScanForward(&_Result, _Val);
} else {
#ifdef _M_IX86
const auto _Low = static_cast<unsigned int>(_Val);
Expand All @@ -1115,24 +1124,31 @@ _NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept {
return static_cast<int>(_Result + 32);
}
#else // ^^^ _M_IX86 / !_M_IX86 vvv
if (!_BitScanForward64(&_Result, _Val)) {
_Bsf_return = _BitScanForward64(&_Result, _Val);
#endif // _M_IX86
}

if constexpr (_Digits >= 32 && _Assumption == _Countr_zero_assumption::_Possibly_zero) {
if (!_Bsf_return) {
return _Digits;
}
#endif // _M_IX86
} else {
(void) _Bsf_return;
}

return static_cast<int>(_Result);
}

template <class _Ty>
template <_Countr_zero_assumption _Assumption, class _Ty>
_NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept {
#ifdef __AVX2__
return _Countr_zero_tzcnt(_Val);
return _Countr_zero_tzcnt<_Assumption>(_Val);
#else // __AVX2__
const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2;
if (_Definitely_have_tzcnt) {
return _Countr_zero_tzcnt(_Val);
return _Countr_zero_tzcnt<_Assumption>(_Val);
} else {
return _Countr_zero_bsf(_Val);
return _Countr_zero_bsf<_Assumption>(_Val);
}
#endif // __AVX2__
}
Expand Down Expand Up @@ -1188,12 +1204,13 @@ template <class _Ty>
constexpr bool _Is_standard_unsigned_integer =
_Is_any_of_v<remove_cv_t<_Ty>, unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>;

template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> = 0>
template <_Countr_zero_assumption _Assumption = _Countr_zero_assumption::_Possibly_zero, //
class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> = 0>
_NODISCARD constexpr int _Countr_zero(const _Ty _Val) noexcept {
#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
#if _HAS_CXX20
if (!_STD is_constant_evaluated()) {
return _Checked_x86_x64_countr_zero(_Val);
return _Checked_x86_x64_countr_zero<_Assumption>(_Val);
}
#endif // _HAS_CXX20
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
Expand Down
7 changes: 4 additions & 3 deletions stl/inc/numeric
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,14 @@ _NODISCARD constexpr common_type_t<_Mt, _Nt> gcd(const _Mt _Mx, const _Nt _Nx) n
return static_cast<_Common>(_Mx_magnitude);
}

const auto _Mx_trailing_zeroes = static_cast<unsigned long>(_Countr_zero(_Mx_magnitude));
constexpr auto _Nonzero = _Countr_zero_assumption::_Nonzero;
const auto _Mx_trailing_zeroes = static_cast<unsigned long>(_Countr_zero<_Nonzero>(_Mx_magnitude));
const auto _Common_factors_of_2 =
(_STD min) (_Mx_trailing_zeroes, static_cast<unsigned long>(_Countr_zero(_Nx_magnitude)));
(_STD min) (_Mx_trailing_zeroes, static_cast<unsigned long>(_Countr_zero<_Nonzero>(_Nx_magnitude)));
_Nx_magnitude >>= _Common_factors_of_2;
_Mx_magnitude >>= _Mx_trailing_zeroes;
do {
_Nx_magnitude >>= static_cast<unsigned long>(_Countr_zero(_Nx_magnitude));
_Nx_magnitude >>= static_cast<unsigned long>(_Countr_zero<_Nonzero>(_Nx_magnitude));
if (_Mx_magnitude > _Nx_magnitude) {
_Common_unsigned _Temp = _Mx_magnitude;
_Mx_magnitude = _Nx_magnitude;
Expand Down
3 changes: 2 additions & 1 deletion stl/inc/vector
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,8 @@ struct _Vbase_compare_three_way {
#endif // ^^^ !defined(__cpp_lib_concepts) ^^^
}

const int _Bit_index = _Countr_zero(_Differing_bits); // number of least significant bits that match
constexpr auto _Nonzero = _Countr_zero_assumption::_Nonzero;
const int _Bit_index = _Countr_zero<_Nonzero>(_Differing_bits); // number of least significant bits that match
_STL_INTERNAL_CHECK(_Bit_index < _VBITS); // because we return early for equality

const _Vbase _Mask = _Vbase{1} << _Bit_index; // selects the least significant bit that differs
Expand Down