Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions libcxx/include/__algorithm/nth_element.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <__algorithm/comp_ref_type.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/sort.h>
#include <__assert>
#include <__config>
#include <__debug_utils/randomize_range.h>
#include <__iterator/iterator_traits.h>
Expand Down Expand Up @@ -116,10 +117,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
return;
}
while (true) {
while (!__comp(*__first, *__i))
while (!__comp(*__first, *__i)) {
++__i;
while (__comp(*__first, *--__j))
;
_LIBCPP_ASSERT_UNCATEGORIZED(
__i != __last,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
}
do {
_LIBCPP_ASSERT_UNCATEGORIZED(
__j != __first,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
--__j;
} while (__comp(*__first, *__j));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
Expand All @@ -146,11 +155,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
while (true)
{
// __m still guards upward moving __i
while (__comp(*__i, *__m))
while (__comp(*__i, *__m)) {
++__i;
_LIBCPP_ASSERT_UNCATEGORIZED(
__i != __last,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
}
// It is now known that a guard exists for downward moving __j
while (!__comp(*--__j, *__m))
;
do {
_LIBCPP_ASSERT_UNCATEGORIZED(
__j != __first,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
--__j;
} while (!__comp(*__j, *__m));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,34 @@
#include "bad_comparator_values.h"
#include "check_assertion.h"

void check_oob_sort_read() {
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
auto values = std::views::split(line, ' ');
auto it = values.begin();
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
comparison_results[left][right] = result;
}
auto predicate = [&](std::size_t* left, std::size_t* right) {
class ComparisonResults {
public:
ComparisonResults(std::string_view data) {
for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
auto values = std::views::split(line, ' ');
auto it = values.begin();
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
comparison_results[left][right] = result;
}
}

bool compare(size_t* left, size_t* right) {
assert(left != nullptr && right != nullptr && "something is wrong with the test");
assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
return comparison_results[*left][*right];
};
}

size_t size() const { return comparison_results.size(); }
private:
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
};

void check_oob_sort_read() {
ComparisonResults comparison_results(SORT_DATA);
std::vector<std::unique_ptr<std::size_t>> elements;
std::set<std::size_t*> valid_ptrs;
for (std::size_t i = 0; i != comparison_results.size(); ++i) {
Expand All @@ -81,7 +91,7 @@ void check_oob_sort_read() {
// because we're reading OOB.
assert(valid_ptrs.contains(left));
assert(valid_ptrs.contains(right));
return predicate(left, right);
return comparison_results.compare(left, right);
};

// Check the classic sorting algorithms
Expand Down Expand Up @@ -165,6 +175,39 @@ void check_oob_sort_read() {
}
}

void check_oob_nth_element_read() {
ComparisonResults results(NTH_ELEMENT_DATA);
std::vector<std::unique_ptr<std::size_t>> elements;
std::set<std::size_t*> valid_ptrs;
for (std::size_t i = 0; i != results.size(); ++i) {
elements.push_back(std::make_unique<std::size_t>(i));
valid_ptrs.insert(elements.back().get());
}

auto checked_predicate = [&](size_t* left, size_t* right) {
// If the pointers passed to the comparator are not in the set of pointers we
// set up above, then we're being passed garbage values from the algorithm
// because we're reading OOB.
assert(valid_ptrs.contains(left));
assert(valid_ptrs.contains(right));
return results.compare(left, right);
};

{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
}

{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
}
}

struct FloatContainer {
float value;
bool operator<(const FloatContainer& other) const {
Expand Down Expand Up @@ -214,6 +257,8 @@ int main(int, char**) {

check_oob_sort_read();

check_oob_nth_element_read();

check_nan_floats();

check_irreflexive();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,74 @@

#include <string_view>

inline constexpr std::string_view DATA = R"(
inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
0 0 0
0 1 0
0 2 0
0 3 0
0 4 1
0 5 0
0 6 0
0 7 0
1 0 0
1 1 0
1 2 0
1 3 1
1 4 1
1 5 1
1 6 1
1 7 1
2 0 1
2 1 1
2 2 1
2 3 1
2 4 1
2 5 1
2 6 1
2 7 1
3 0 1
3 1 1
3 2 1
3 3 1
3 4 1
3 5 1
3 6 1
3 7 1
4 0 1
4 1 1
4 2 1
4 3 1
4 4 1
4 5 1
4 6 1
4 7 1
5 0 1
5 1 1
5 2 1
5 3 1
5 4 1
5 5 1
5 6 1
5 7 1
6 0 1
6 1 1
6 2 1
6 3 1
6 4 1
6 5 1
6 6 1
6 7 1
7 0 1
7 1 1
7 2 1
7 3 1
7 4 1
7 5 1
7 6 1
7 7 1
)";

inline constexpr std::string_view SORT_DATA = R"(
0 0 0
0 1 1
0 2 1
Expand Down