diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 59db8eb1959..6c804c2e194 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -606,26 +606,33 @@ namespace ranges { _NODISCARD constexpr iter_difference_t<_It> operator()( _It _First, _Se _Last, const _Ty& _Val, _Pj _Proj = {}) const { _Adl_verify_range(_First, _Last); + return _Count_unchecked( + _Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)), _Val, _Pass_fn(_Proj)); + } - auto _UFirst = _Get_unwrapped(_STD move(_First)); - const auto _ULast = _Get_unwrapped(_STD move(_Last)); - iter_difference_t<_It> _Count = 0; + template + requires indirect_binary_predicate, _Pj>, const _Ty*> + _NODISCARD constexpr range_difference_t<_Rng> operator()(_Rng&& _Range, const _Ty& _Val, _Pj _Proj = {}) const { + return _Count_unchecked(_Ubegin(_Range), _Uend(_Range), _Val, _Pass_fn(_Proj)); + } + // clang-format on + private: + template + _NODISCARD static constexpr iter_difference_t<_It> _Count_unchecked( + _It _First, const _Se _Last, const _Ty& _Val, _Pj _Proj) { + _STL_INTERNAL_STATIC_ASSERT(input_iterator<_It>); + _STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>); + _STL_INTERNAL_STATIC_ASSERT(indirect_binary_predicate, const _Ty*>); - for (; _UFirst != _ULast; ++_UFirst) { - if (_STD invoke(_Proj, *_UFirst) == _Val) { + iter_difference_t<_It> _Count = 0; + for (; _First != _Last; ++_First) { + if (_STD invoke(_Proj, *_First) == _Val) { ++_Count; } } return _Count; } - - template - requires indirect_binary_predicate, _Pj>, const _Ty*> - _NODISCARD constexpr range_difference_t<_Rng> operator()(_Rng&& _Range, const _Ty& _Val, _Pj _Proj = {}) const { - return (*this)(_RANGES begin(_Range), _RANGES end(_Range), _Val, _Pass_fn(_Proj)); - } - // clang-format on }; inline constexpr _Count_fn count{_Not_quite_object::_Construct_tag{}}; diff --git a/tests/std/tests/P0896R4_ranges_alg_count/test.cpp b/tests/std/tests/P0896R4_ranges_alg_count/test.cpp index 14912670975..05145d1e87c 100644 --- a/tests/std/tests/P0896R4_ranges_alg_count/test.cpp +++ b/tests/std/tests/P0896R4_ranges_alg_count/test.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include -#include #include #include #include @@ -10,44 +9,33 @@ #include -constexpr void smoke_test() { - using ranges::count; - using P = std::pair; - std::array const x = {{{0, 99}, {1, 47}, {2, 99}, {3, 47}, {4, 99}}}; - using D = ranges::range_difference_t>; - - { - // Validate range overload - auto result = count(basic_borrowed_range{x}, 99, get_second); - STATIC_ASSERT(std::same_as); - assert(result == 3); - } - { - // Validate iterator + sentinel overload - basic_borrowed_range wrapped_x{x}; - auto result = count(wrapped_x.begin(), wrapped_x.end(), 47, get_second); - STATIC_ASSERT(std::same_as); - assert(result == 2); - } -} - -int main() { - STATIC_ASSERT((smoke_test(), true)); - smoke_test(); -} +using namespace std; +using P = pair; struct instantiator { - template - static void call(In&& in = {}) { - ranges::range_value_t const value{}; - (void) ranges::count(in, value); - - struct type { - bool operator==(type const&) const = default; - }; - using Projection = type (*)(std::iter_common_reference_t>); - (void) ranges::count(in, type{}, Projection{}); + static constexpr P input[5] = {{0, 99}, {1, 47}, {2, 99}, {3, 47}, {4, 99}}; + + template + static constexpr void call() { + using ranges::count; + { // Validate iterator + sentinel overload + Read wrapped_input{input}; + + auto result = count(wrapped_input.begin(), wrapped_input.end(), 47, get_second); + STATIC_ASSERT(same_as>); + assert(result == 2); + } + { // Validate range overload + Read wrapped_input{input}; + + auto result = count(wrapped_input, 99, get_second); + STATIC_ASSERT(same_as>); + assert(result == 3); + } } }; -template void test_in(); +int main() { + STATIC_ASSERT((test_in(), true)); + test_in(); +}