diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 59db8eb195..75cf081c49 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -666,23 +666,33 @@ namespace ranges { indirect_unary_predicate> _Pr> _NODISCARD constexpr iter_difference_t<_It> operator()(_It _First, _Se _Last, _Pr _Pred, _Pj _Proj = {}) const { _Adl_verify_range(_First, _Last); - auto _UFirst = _Get_unwrapped(_STD move(_First)); - const auto _ULast = _Get_unwrapped(_STD move(_Last)); + return _Count_if_unchecked( + _Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)), _Pass_fn(_Pred), _Pass_fn(_Proj)); + } + + template , _Pj>> _Pr> + _NODISCARD constexpr range_difference_t<_Rng> operator()(_Rng&& _Range, _Pr _Pred, _Pj _Proj = {}) const { + return _Count_if_unchecked(_Ubegin(_Range), _Uend(_Range), _Pass_fn(_Pred), _Pass_fn(_Proj)); + } + + private: + template + _NODISCARD static constexpr iter_difference_t<_It> _Count_if_unchecked( + _It _First, const _Se _Last, _Pr _Pred, _Pj _Proj) { + _STL_INTERNAL_STATIC_ASSERT(input_iterator<_It>); + _STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>); + _STL_INTERNAL_STATIC_ASSERT(indirect_unary_predicate<_Pr, projected<_It, _Pj>>); + iter_difference_t<_It> _Count = 0; - for (; _UFirst != _ULast; ++_UFirst) { - if (_STD invoke(_Pred, _STD invoke(_Proj, *_UFirst))) { + for (; _First != _Last; ++_First) { + if (_STD invoke(_Pred, _STD invoke(_Proj, *_First))) { ++_Count; } } return _Count; } - - template , _Pj>> _Pr> - _NODISCARD constexpr range_difference_t<_Rng> operator()(_Rng&& _Range, _Pr _Pred, _Pj _Proj = {}) const { - return (*this)(_RANGES begin(_Range), _RANGES end(_Range), _Pass_fn(_Pred), _Pass_fn(_Proj)); - } }; inline constexpr _Count_if_fn count_if{_Not_quite_object::_Construct_tag{}}; diff --git a/tests/std/tests/P0896R4_ranges_alg_count_if/test.cpp b/tests/std/tests/P0896R4_ranges_alg_count_if/test.cpp index 70d0e51f36..33e3ff2670 100644 --- a/tests/std/tests/P0896R4_ranges_alg_count_if/test.cpp +++ b/tests/std/tests/P0896R4_ranges_alg_count_if/test.cpp @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include -#include #include #include #include @@ -13,41 +12,33 @@ constexpr auto is_even = [](auto const& x) { return x % 2 == 0; }; constexpr auto is_odd = [](auto const& x) { return x % 2 != 0; }; -constexpr void smoke_test() { - using ranges::count_if; - using P = std::pair; - std::array const x = {{{0, 47}, {1, 99}, {2, 99}, {3, 47}, {4, 99}}}; - using D = ranges::range_difference_t>; - - { - // Validate range overload - auto result = count_if(basic_borrowed_range{x}, is_even, get_first); - STATIC_ASSERT(std::same_as); - assert(result == 3); - } - { - // Validate iterator + sentinel overload - basic_borrowed_range wrapped_x{x}; - auto result = count_if(wrapped_x.begin(), wrapped_x.end(), is_odd, get_first); - STATIC_ASSERT(std::same_as); - assert(result == 2); - } -} - -int main() { - STATIC_ASSERT((smoke_test(), true)); - smoke_test(); -} +using namespace std; +using P = pair; struct instantiator { - template - static void call(In&& in = {}) { - using ranges::iterator_t; - using I = iterator_t; - - (void) ranges::count_if(in, UnaryPredicateFor{}); - (void) ranges::count_if(in, ProjectedUnaryPredicate<>{}, ProjectionFor{}); + static constexpr P input[5] = {{0, 47}, {1, 99}, {2, 99}, {3, 47}, {4, 99}}; + + template + static constexpr void call() { + using ranges::count_if; + { // Validate iterator + sentinel overload + Read wrapped_input{input}; + + auto result = count_if(wrapped_input.begin(), wrapped_input.end(), is_odd, get_first); + STATIC_ASSERT(same_as>); + assert(result == 2); + } + { // Validate range overload + Read wrapped_input{input}; + + auto result = count_if(wrapped_input, is_even, get_first); + STATIC_ASSERT(same_as>); + assert(result == 3); + } } }; -template void test_in(); +int main() { + STATIC_ASSERT((test_in(), true)); + test_in(); +}