Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions stl/inc/limits
Original file line number Diff line number Diff line change
Expand Up @@ -1064,41 +1064,79 @@ extern int __isa_available;
}

template <class _Ty>
_NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept {
_NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept {
constexpr int _Digits = numeric_limits<_Ty>::digits;
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();

#ifndef __AVX2__
// Because the widening done below will always give a non-0 value, checking for tzcnt
// is not required for 8-bit and 16-bit since the only difference in behavior between
// bsf and tzcnt is when the value is 0.
if constexpr (_Digits > 16) {
const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2;
if (!_Definitely_have_tzcnt && _Val == 0) {
return _Digits;
if constexpr (_Digits <= 32) {
// Intended widening to int. This operation means that a narrow 0 will widen
// to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros
// of the wider type.
return static_cast<int>(_TZCNT_U32(static_cast<unsigned int>(~_Max | _Val)));
} else {
#ifdef _M_IX86
const auto _Low = static_cast<unsigned int>(_Val);
if (_Low == 0) {
const unsigned int _High = _Val >> 32;
return static_cast<int>(32 + _TZCNT_U32(_High));
} else {
return static_cast<int>(_TZCNT_U32(_Low));
}
#else // ^^^ _M_IX86 / !_M_IX86 vvv
return static_cast<int>(_TZCNT_U64(_Val));
#endif // _M_IX86
}
#endif // __AVX2__
}

template <class _Ty>
_NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept {
constexpr int _Digits = numeric_limits<_Ty>::digits;
constexpr _Ty _Max = (numeric_limits<_Ty>::max) ();

unsigned long _Result;
if constexpr (_Digits <= 32) {
// Intended widening to int. This operation means that a narrow 0 will widen
// to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros
// of the wider type.
return static_cast<int>(_TZCNT_U32(static_cast<unsigned int>(~_Max | _Val)));
if (!_BitScanForward(&_Result, static_cast<unsigned int>(~_Max | _Val))) {
return _Digits;
}
} else {
#ifdef _M_IX86
const auto _Low = static_cast<unsigned int>(_Val);
if (_BitScanForward(&_Result, _Low)) {
return static_cast<int>(_Result);
}

const unsigned int _High = _Val >> 32;
const unsigned int _Low = static_cast<unsigned int>(_Val);
if (_Low == 0) {
return 32 + _Checked_x86_x64_countr_zero(_High);
if (!_BitScanForward(&_Result, _High)) {
return _Digits;
} else {
return _Checked_x86_x64_countr_zero(_Low);
return static_cast<int>(_Result + 32);
}
#else // ^^^ _M_IX86 / !_M_IX86 vvv
return static_cast<int>(_TZCNT_U64(_Val));
if (!_BitScanForward64(&_Result, _Val)) {
return _Digits;
}
#endif // _M_IX86
}
return static_cast<int>(_Result);
}

template <class _Ty>
_NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept {
#ifdef __AVX2__
return _Countr_zero_tzcnt(_Val);
#else // __AVX2__
const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2;
if (_Definitely_have_tzcnt) {
return _Countr_zero_tzcnt(_Val);
} else {
return _Countr_zero_bsf(_Val);
}
#endif // __AVX2__
}

#undef _TZCNT_U32
#undef _TZCNT_U64
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
Expand Down
25 changes: 25 additions & 0 deletions tests/std/tests/GH_001103_countl_zero_correctness/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,30 @@ int main() {
assert(_Countl_zero_bsr(static_cast<unsigned long long>(0x0000'0000'0000'0013)) == 59);
assert(_Countl_zero_bsr(static_cast<unsigned long long>(0x8000'0000'0000'0003)) == 0);
assert(_Countl_zero_bsr(static_cast<unsigned long long>(0xF000'0000'0000'0008)) == 0);

assert(_Countr_zero_bsf(static_cast<unsigned char>(0x00)) == 8);
assert(_Countr_zero_bsf(static_cast<unsigned char>(0x13)) == 0);
assert(_Countr_zero_bsf(static_cast<unsigned char>(0x80)) == 7);
assert(_Countr_zero_bsf(static_cast<unsigned char>(0xF8)) == 3);

assert(_Countr_zero_bsf(static_cast<unsigned short>(0x0000)) == 16);
assert(_Countr_zero_bsf(static_cast<unsigned short>(0x0013)) == 0);
assert(_Countr_zero_bsf(static_cast<unsigned short>(0x8000)) == 15);
assert(_Countr_zero_bsf(static_cast<unsigned short>(0xF008)) == 3);

assert(_Countr_zero_bsf(static_cast<unsigned int>(0x0000'0000)) == 32);
assert(_Countr_zero_bsf(static_cast<unsigned int>(0x0000'0013)) == 0);
assert(_Countr_zero_bsf(static_cast<unsigned int>(0x8000'0000)) == 31);
assert(_Countr_zero_bsf(static_cast<unsigned int>(0xF000'0008)) == 3);

assert(_Countr_zero_bsf(static_cast<unsigned long>(0x0000'0000)) == 32);
assert(_Countr_zero_bsf(static_cast<unsigned long>(0x0000'0013)) == 0);
assert(_Countr_zero_bsf(static_cast<unsigned long>(0x8000'0000)) == 31);
assert(_Countr_zero_bsf(static_cast<unsigned long>(0xF000'0008)) == 3);

assert(_Countr_zero_bsf(static_cast<unsigned long long>(0x0000'0000'0000'0000)) == 64);
assert(_Countr_zero_bsf(static_cast<unsigned long long>(0x0000'0000'0000'0013)) == 0);
assert(_Countr_zero_bsf(static_cast<unsigned long long>(0x8000'0000'0000'0000)) == 63);
assert(_Countr_zero_bsf(static_cast<unsigned long long>(0xF000'0000'0000'0008)) == 3);
#endif // ^^^ defined(_M_IX86) || defined(_M_X64) ^^^
}