Make more use of llvm STLExtras (#4668)
This is essentially the result of looking at `.begin()` uses. We also
frequently do `std::shuffle`, but unfortunately STLExtras doesn't
provide a wrapper for that.
jonmeow authored Dec 11, 2024
1 parent 47285b6 commit 61c0a8b
Showing 15 changed files with 52 additions and 64 deletions.
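The pattern throughout the diff is the same: replace an iterator-pair call to a `std::` algorithm with the range-based wrapper from `llvm/ADT/STLExtras.h`. A minimal standalone sketch of that before/after shape (the helper names here are illustrative, not from the commit):

```cpp
#include <algorithm>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helpers; the names do not come from the commit.
static auto SortAndAppendBefore(llvm::SmallVector<int>& values,
                                llvm::ArrayRef<int> extra) -> void {
  // Iterator-pair style: each call spells out .begin()/.end().
  std::sort(values.begin(), values.end());
  values.append(extra.begin(), extra.end());
}

static auto SortAndAppendAfter(llvm::SmallVector<int>& values,
                               llvm::ArrayRef<int> extra) -> void {
  // Range style: pass the container; the wrappers forward to the same
  // algorithms. llvm::sort additionally shuffles the range in
  // expensive-checks builds to surface comparator non-determinism.
  llvm::sort(values);
  llvm::append_range(values, extra);
}
```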
2 changes: 1 addition & 1 deletion common/array_stack.h
@@ -78,7 +78,7 @@ class ArrayStack {
auto AppendToTop(llvm::ArrayRef<ValueT> values) -> void {
CARBON_CHECK(!array_offsets_.empty(),
"Must call PushArray before PushValues.");
values_.append(values.begin(), values.end());
llvm::append_range(values_, values);
}

// Returns the current number of values in all arrays.
7 changes: 3 additions & 4 deletions common/command_line.cpp
@@ -1070,10 +1070,9 @@ auto Parser::FinalizeParsedOptions() -> ErrorOr<Success> {
// Sort the missing arguments by name to provide a stable and deterministic
// error message. We know there can't be duplicate names because these came
// from a map keyed on the name, so this provides a total ordering.
std::sort(missing_options.begin(), missing_options.end(),
[](const Arg* lhs, const Arg* rhs) {
return lhs->info.name < rhs->info.name;
});
llvm::sort(missing_options, [](const Arg* lhs, const Arg* rhs) {
return lhs->info.name < rhs->info.name;
});

std::string error_str = "required options not provided: ";
llvm::raw_string_ostream error(error_str);
33 changes: 14 additions & 19 deletions common/hashing_test.cpp
@@ -592,9 +592,9 @@ auto FindBitRangeCollisions(llvm::ArrayRef<HashedValue<T>> hashes)

// Now we sort by the extracted bit sequence so we can efficiently scan for
// colliding bit patterns.
std::sort(
bits_and_indices.begin(), bits_and_indices.end(),
[](const auto& lhs, const auto& rhs) { return lhs.bits < rhs.bits; });
llvm::sort(bits_and_indices, [](const auto& lhs, const auto& rhs) {
return lhs.bits < rhs.bits;
});

// Scan the sorted bit sequences we've extracted looking for collisions. We
// count the total collisions, but we also track the number of individual
@@ -635,16 +635,15 @@ auto FindBitRangeCollisions(llvm::ArrayRef<HashedValue<T>> hashes)
}

// Sort by collision count for each hash.
std::sort(bits_and_indices.begin(), bits_and_indices.end(),
[&](const auto& lhs, const auto& rhs) {
return collision_counts[collision_map[lhs.index]] <
collision_counts[collision_map[rhs.index]];
});
llvm::sort(bits_and_indices, [&](const auto& lhs, const auto& rhs) {
return collision_counts[collision_map[lhs.index]] <
collision_counts[collision_map[rhs.index]];
});

// And compute the median and max.
int median = collision_counts
[collision_map[bits_and_indices[bits_and_indices.size() / 2].index]];
int max = *std::max_element(collision_counts.begin(), collision_counts.end());
int max = *llvm::max_element(collision_counts);
CARBON_CHECK(max ==
collision_counts[collision_map[bits_and_indices.back().index]]);
return {.total = total, .median = median, .max = max};
@@ -672,11 +671,9 @@ auto AllByteStringsHashedAndSorted() {
hashes.push_back({HashValue(s, TestSeed), s});
}

std::sort(hashes.begin(), hashes.end(),
[](const HashedString& lhs, const HashedString& rhs) {
return static_cast<uint64_t>(lhs.hash) <
static_cast<uint64_t>(rhs.hash);
});
llvm::sort(hashes, [](const HashedString& lhs, const HashedString& rhs) {
return static_cast<uint64_t>(lhs.hash) < static_cast<uint64_t>(rhs.hash);
});
CheckNoDuplicateValues(hashes);

return hashes;
@@ -832,11 +829,9 @@ struct SparseHashTest : ::testing::Test {
}
}

std::sort(hashes.begin(), hashes.end(),
[](const HashedString& lhs, const HashedString& rhs) {
return static_cast<uint64_t>(lhs.hash) <
static_cast<uint64_t>(rhs.hash);
});
llvm::sort(hashes, [](const HashedString& lhs, const HashedString& rhs) {
return static_cast<uint64_t>(lhs.hash) < static_cast<uint64_t>(rhs.hash);
});
CheckNoDuplicateValues(hashes);

return hashes;
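The `hashing_test.cpp` hunks above also swap `std::max_element` for `llvm::max_element`. A small sketch of that wrapper, with toy data rather than the test's collision counts:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static auto MaxCollisionCount(const llvm::SmallVector<int>& collision_counts)
    -> int {
  // Like std::max_element, llvm::max_element returns an iterator, so the
  // result still needs a dereference; a comparator overload also exists.
  // (Assumes a non-empty range, as the test does.)
  return *llvm::max_element(collision_counts);
}
```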
8 changes: 4 additions & 4 deletions common/raw_hashtable_benchmark_helpers.cpp
@@ -349,10 +349,10 @@ auto DumpHashStatistics(llvm::ArrayRef<T> keys) -> void {
grouped_key_indices[hash_index].push_back(i);
}
ssize_t max_group_index =
std::max_element(grouped_key_indices.begin(), grouped_key_indices.end(),
[](const auto& lhs, const auto& rhs) {
return lhs.size() < rhs.size();
}) -
llvm::max_element(grouped_key_indices,
[](const auto& lhs, const auto& rhs) {
return lhs.size() < rhs.size();
}) -
grouped_key_indices.begin();

// If the max number of collisions on the index is less than or equal to the
6 changes: 3 additions & 3 deletions testing/base/source_gen.cpp
@@ -206,7 +206,7 @@ auto SourceGen::ClassGenState::BuildClassAndTypeNames(
int num_declared_types =
num_types * type_use_params.declared_types_weight / type_weight_sum;
for ([[maybe_unused]] auto _ : llvm::seq(num_declared_types / num_classes)) {
type_names_.append(class_names_.begin(), class_names_.end());
llvm::append_range(type_names_, class_names_);
}
// Now append the remainder number of class names. This is where the class
// names being un-shuffled is essential. We're going to have one extra
@@ -389,8 +389,8 @@ auto SourceGen::GetIdentifiers(int number, int min_length, int max_length,
number, min_length, max_length, uniform,
[this](int length, int length_count,
llvm::SmallVectorImpl<llvm::StringRef>& dest) {
auto length_idents = GetSingleLengthIdentifiers(length, length_count);
dest.append(length_idents.begin(), length_idents.end());
llvm::append_range(dest,
GetSingleLengthIdentifiers(length, length_count));
});

return idents;
4 changes: 2 additions & 2 deletions testing/file_test/autoupdate.h
@@ -58,8 +58,8 @@ class FileTestAutoupdater {
// initialization.
stdout_(BuildCheckLines(stdout, "STDOUT")),
stderr_(BuildCheckLines(stderr, "STDERR")),
any_attached_stdout_lines_(std::any_of(
stdout_.lines.begin(), stdout_.lines.end(),
any_attached_stdout_lines_(llvm::any_of(
stdout_.lines,
[&](const CheckLine& line) { return line.line_number() != -1; })),
non_check_line_(non_check_lines_.begin()) {
for (const auto& replacement : line_number_replacements_) {
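`llvm::any_of`, used above, takes the range and predicate directly. A sketch with a toy stand-in for `CheckLine` (the struct below is illustrative, not the framework's type):

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Toy stand-in for the autoupdater's CheckLine type.
struct ToyCheckLine {
  int line_number;
};

static auto AnyAttached(const llvm::SmallVector<ToyCheckLine>& lines) -> bool {
  // llvm::any_of takes the range plus the predicate, mirroring std::any_of
  // over an explicit .begin()/.end() pair.
  return llvm::any_of(
      lines, [](const ToyCheckLine& line) { return line.line_number != -1; });
}
```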
2 changes: 1 addition & 1 deletion toolchain/check/deduce.cpp
@@ -272,7 +272,7 @@ DeductionContext::DeductionContext(Context& context, SemIR::LocId loc_id,
// to substitute them into the function declaration.
auto args = context.inst_blocks().Get(
context.specifics().Get(enclosing_specific_id).args_id);
std::copy(args.begin(), args.end(), result_arg_ids_.begin());
llvm::copy(args, result_arg_ids_.begin());

// TODO: Subst is linear in the length of the substitutions list. Change
// it so we can pass in an array mapping indexes to substitutions instead.
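One caveat worth noting for the `deduce.cpp` hunk: `llvm::copy` only wraps the source range, so the destination iterator must already point at enough storage. A sketch under that assumption, with made-up names:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static auto CopyIntoPresized(llvm::ArrayRef<int> args)
    -> llvm::SmallVector<int> {
  // llvm::copy wraps only the source range; the destination is still a plain
  // output iterator, so it must already have room for every element.
  llvm::SmallVector<int> result(args.size());
  llvm::copy(args, result.begin());
  return result;
}
```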
12 changes: 5 additions & 7 deletions toolchain/check/impl_lookup.cpp
@@ -38,8 +38,7 @@ static auto FindAssociatedImportIRs(Context& context,
// Push the contents of an instruction block onto our worklist.
auto push_block = [&](SemIR::InstBlockId block_id) {
if (block_id.is_valid()) {
auto block = context.inst_blocks().Get(block_id);
worklist.append(block.begin(), block.end());
llvm::append_range(worklist, context.inst_blocks().Get(block_id));
}
};

@@ -102,11 +101,10 @@ static auto FindAssociatedImportIRs(Context& context,
}

// Deduplicate.
std::sort(result.begin(), result.end(),
[](SemIR::ImportIRId a, SemIR::ImportIRId b) {
return a.index < b.index;
});
result.erase(std::unique(result.begin(), result.end()), result.end());
llvm::sort(result, [](SemIR::ImportIRId a, SemIR::ImportIRId b) {
return a.index < b.index;
});
result.erase(llvm::unique(result), result.end());

return result;
}
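The deduplication idiom above, sketched with plain integers instead of `SemIR::ImportIRId`:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static auto Deduplicate(llvm::SmallVector<int>& result) -> void {
  // Sort so duplicates become adjacent, then llvm::unique shifts the unique
  // prefix forward and returns the new logical end for erase to trim.
  llvm::sort(result);
  result.erase(llvm::unique(result), result.end());
}
```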
5 changes: 2 additions & 3 deletions toolchain/check/scope_stack.cpp
@@ -129,9 +129,8 @@ auto ScopeStack::LookupInLexicalScopes(SemIR::NameId name_id)

// Find the first non-lexical scope that is within the scope of the lexical
// lookup result.
auto* first_non_lexical_scope = std::lower_bound(
non_lexical_scope_stack_.begin(), non_lexical_scope_stack_.end(),
lexical_results.back().scope_index,
auto* first_non_lexical_scope = llvm::lower_bound(
non_lexical_scope_stack_, lexical_results.back().scope_index,
[](const NonLexicalScope& scope, ScopeIndex index) {
return scope.scope_index < index;
});
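A sketch of the `llvm::lower_bound` form used above, with a toy scope type standing in for `NonLexicalScope`; the range must already be sorted by the comparison key:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Toy stand-in for the non-lexical scope entries.
struct ToyScope {
  int scope_index;
};

static auto FirstScopeAtOrAbove(llvm::SmallVector<ToyScope>& scopes, int index)
    -> ToyScope* {
  // lower_bound returns the first element for which the comparator reports
  // "less than the target" as false, i.e. the first scope whose scope_index
  // is >= index.
  return llvm::lower_bound(scopes, index,
                           [](const ToyScope& scope, int target) {
                             return scope.scope_index < target;
                           });
}
```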
5 changes: 2 additions & 3 deletions toolchain/lex/lex.cpp
@@ -1555,9 +1555,8 @@ auto Lexer::DiagnoseAndFixMismatchedBrackets() -> void {
}

// Find the innermost matching opening symbol.
auto opening_it = std::find_if(
open_groups_.rbegin(), open_groups_.rend(),
[&](TokenIndex opening_token) {
auto opening_it = llvm::find_if(
llvm::reverse(open_groups_), [&](TokenIndex opening_token) {
return buffer_.GetTokenInfo(opening_token).kind().closing_symbol() ==
kind;
});
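`llvm::reverse` plus `llvm::find_if` replaces the `rbegin()`/`rend()` pair while keeping the back-to-front search. A simplified sketch with integers in place of `TokenIndex`:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

static auto LastMatching(const llvm::SmallVector<int>& open_groups, int kind)
    -> const int* {
  // llvm::reverse adapts the range so the search walks back to front, the
  // same traversal the old rbegin()/rend() iterator pair expressed.
  auto reversed = llvm::reverse(open_groups);
  auto it =
      llvm::find_if(reversed, [&](int opening) { return opening == kind; });
  return it == reversed.end() ? nullptr : &*it;
}
```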
5 changes: 2 additions & 3 deletions toolchain/lex/numeric_literal.cpp
@@ -185,9 +185,8 @@ static auto ParseInt(llvm::StringRef digits, NumericLiteral::Radix radix,
llvm::SmallString<32> cleaned;
if (needs_cleaning) {
cleaned.reserve(digits.size());
std::remove_copy_if(digits.begin(), digits.end(),
std::back_inserter(cleaned),
[](char c) { return c == '_' || c == '.'; });
llvm::copy_if(digits, std::back_inserter(cleaned),
[](char c) { return c != '_' && c != '.'; });
digits = cleaned;
}

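Note the predicate flip here: `std::remove_copy_if` copies elements that fail its predicate, while `copy_if` copies elements that pass, so the condition is negated by De Morgan (`c == '_' || c == '.'` becomes `c != '_' && c != '.'`). A standalone sketch of the new form:

```cpp
#include <iterator>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"

static auto StripSeparators(llvm::StringRef digits) -> llvm::SmallString<32> {
  llvm::SmallString<32> cleaned;
  cleaned.reserve(digits.size());
  // copy_if keeps characters that satisfy the predicate, so the condition is
  // the negation of the one remove_copy_if filtered out.
  llvm::copy_if(digits, std::back_inserter(cleaned),
                [](char c) { return c != '_' && c != '.'; });
  return cleaned;
}
```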
11 changes: 5 additions & 6 deletions toolchain/lex/tokenized_buffer.cpp
@@ -317,10 +317,9 @@ auto TokenizedBuffer::PrintToken(llvm::raw_ostream& output_stream,
auto TokenizedBuffer::FindLineIndex(int32_t byte_offset) const -> LineIndex {
CARBON_DCHECK(!line_infos_.empty());
const auto* line_it =
std::partition_point(line_infos_.begin(), line_infos_.end(),
[byte_offset](LineInfo line_info) {
return line_info.start <= byte_offset;
});
llvm::partition_point(line_infos_, [byte_offset](LineInfo line_info) {
return line_info.start <= byte_offset;
});
--line_it;

// If this isn't the first line but it starts past the end of the source, then
@@ -386,8 +385,8 @@ auto TokenizedBuffer::SourceBufferDiagnosticConverter::ConvertLoc(
int32_t offset = loc - buffer_->source_->text().begin();

// Find the first line starting after the given location.
const auto* next_line_it = std::partition_point(
buffer_->line_infos_.begin(), buffer_->line_infos_.end(),
const auto* next_line_it = llvm::partition_point(
buffer_->line_infos_,
[offset](const LineInfo& line) { return line.start <= offset; });

// Step back one line to find the line containing the given position.
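`llvm::partition_point` expects a predicate that is true for a prefix of the range and false for the rest, and returns the first element where it is false. A sketch with a toy line-info type (not the buffer's real `LineInfo`):

```cpp
#include <cstdint>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// Toy stand-in for LineInfo.
struct ToyLineInfo {
  int32_t start;
};

static auto FirstLinePast(const llvm::SmallVector<ToyLineInfo>& lines,
                          int32_t offset) -> const ToyLineInfo* {
  // Returns the first element whose start is beyond `offset`; the lines must
  // already be ordered by their start offsets.
  return llvm::partition_point(lines, [offset](const ToyLineInfo& line) {
    return line.start <= offset;
  });
}
```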
2 changes: 1 addition & 1 deletion toolchain/parse/extract.cpp
@@ -201,7 +201,7 @@ auto NodeExtractor::MatchesNodeIdOneOf(
*trace_ << "\n";
}
return false;
} else if (std::find(kinds.begin(), kinds.end(), node_kind) == kinds.end()) {
} else if (llvm::find(kinds, node_kind) == kinds.end()) {
if (trace_) {
*trace_ << "NodeIdOneOf error: wrong kind " << node_kind << ", expected ";
trace_kinds();
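As an aside, for a pure membership test STLExtras also offers `llvm::is_contained`; the commit keeps the `find` spelling. A sketch showing the two are equivalent:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

static auto IsKnownKind(llvm::ArrayRef<int> kinds, int kind) -> bool {
  // Two equivalent spellings of a membership test over a range.
  bool via_find = llvm::find(kinds, kind) != kinds.end();
  bool via_contains = llvm::is_contained(kinds, kind);
  return via_find && via_contains;
}
```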
10 changes: 5 additions & 5 deletions toolchain/sem_ir/facet_type_info.cpp
@@ -7,14 +7,14 @@
namespace Carbon::SemIR {

template <typename VecT>
static auto SortAndDeduplicate(VecT* vec) -> void {
std::sort(vec->begin(), vec->end());
vec->erase(std::unique(vec->begin(), vec->end()), vec->end());
static auto SortAndDeduplicate(VecT& vec) -> void {
llvm::sort(vec);
vec.erase(llvm::unique(vec), vec.end());
}

auto FacetTypeInfo::Canonicalize() -> void {
SortAndDeduplicate(&impls_constraints);
SortAndDeduplicate(&rewrite_constraints);
SortAndDeduplicate(impls_constraints);
SortAndDeduplicate(rewrite_constraints);
}

auto FacetTypeInfo::Print(llvm::raw_ostream& out) const -> void {
4 changes: 2 additions & 2 deletions toolchain/testing/coverage_helper.h
@@ -64,14 +64,14 @@ auto TestKindCoverage(const std::string& manifest_path,

constexpr llvm::StringLiteral Bullet = "\n - ";

std::sort(missing_kinds.begin(), missing_kinds.end());
llvm::sort(missing_kinds);
EXPECT_TRUE(missing_kinds.empty()) << "Some kinds have no tests:" << Bullet
<< llvm::join(missing_kinds, Bullet);

llvm::SmallVector<std::string> unexpected_matches;
covered_kinds.ForEach(
[&](const std::string& match) { unexpected_matches.push_back(match); });
std::sort(unexpected_matches.begin(), unexpected_matches.end());
llvm::sort(unexpected_matches);
EXPECT_TRUE(unexpected_matches.empty())
<< "Matched things that aren't in the kind list:" << Bullet
<< llvm::join(unexpected_matches, Bullet);
