Skip to content

Commit 62226ee

Browse files
committed
Improve BinaryHeap::retain.
It no longer fully rebuilds the heap; instead it rebuilds only the parts that are necessary.
1 parent 88b99de commit 62226ee

File tree

1 file changed

+53
-32
lines changed

1 file changed

+53
-32
lines changed

Diff for: library/alloc/src/collections/binary_heap.rs

+53-32
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,43 @@ impl<T: Ord> BinaryHeap<T> {
652652
unsafe { self.sift_up(start, pos) };
653653
}
654654

655+
/// Rebuild assuming data[0..start] is still a proper heap.
656+
fn rebuild_tail(&mut self, start: usize) {
657+
if start == self.len() {
658+
return;
659+
}
660+
661+
let tail_len = self.len() - start;
662+
663+
#[inline(always)]
664+
fn log2_fast(x: usize) -> usize {
665+
(usize::BITS - x.leading_zeros() - 1) as usize
666+
}
667+
668+
// `rebuild` takes O(self.len()) operations
669+
// and about 2 * self.len() comparisons in the worst case
670+
// while repeating `sift_up` takes O(tail_len * log(start)) operations
671+
// and about 1 * tail_len * log_2(start) comparisons in the worst case,
672+
// assuming start >= tail_len. For larger heaps, the crossover point
673+
// no longer follows this reasoning and was determined empirically.
674+
let better_to_rebuild = if start < tail_len {
675+
true
676+
} else if self.len() <= 2048 {
677+
2 * self.len() < tail_len * log2_fast(start)
678+
} else {
679+
2 * self.len() < tail_len * 11
680+
};
681+
682+
if better_to_rebuild {
683+
self.rebuild();
684+
} else {
685+
for i in start..self.len() {
686+
// SAFETY: The index `i` is always less than self.len().
687+
unsafe { self.sift_up(0, i) };
688+
}
689+
}
690+
}
691+
655692
fn rebuild(&mut self) {
656693
let mut n = self.len() / 2;
657694
while n > 0 {
@@ -689,37 +726,11 @@ impl<T: Ord> BinaryHeap<T> {
689726
swap(self, other);
690727
}
691728

692-
if other.is_empty() {
693-
return;
694-
}
695-
696-
#[inline(always)]
697-
fn log2_fast(x: usize) -> usize {
698-
(usize::BITS - x.leading_zeros() - 1) as usize
699-
}
729+
let start = self.data.len();
700730

701-
// `rebuild` takes O(len1 + len2) operations
702-
// and about 2 * (len1 + len2) comparisons in the worst case
703-
// while `extend` takes O(len2 * log(len1)) operations
704-
// and about 1 * len2 * log_2(len1) comparisons in the worst case,
705-
// assuming len1 >= len2. For larger heaps, the crossover point
706-
// no longer follows this reasoning and was determined empirically.
707-
#[inline]
708-
fn better_to_rebuild(len1: usize, len2: usize) -> bool {
709-
let tot_len = len1 + len2;
710-
if tot_len <= 2048 {
711-
2 * tot_len < len2 * log2_fast(len1)
712-
} else {
713-
2 * tot_len < len2 * 11
714-
}
715-
}
731+
self.data.append(&mut other.data);
716732

717-
if better_to_rebuild(self.len(), other.len()) {
718-
self.data.append(&mut other.data);
719-
self.rebuild();
720-
} else {
721-
self.extend(other.drain());
722-
}
733+
self.rebuild_tail(start);
723734
}
724735

725736
/// Returns an iterator which retrieves elements in heap order.
@@ -770,12 +781,22 @@ impl<T: Ord> BinaryHeap<T> {
770781
/// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
771782
/// ```
772783
#[unstable(feature = "binary_heap_retain", issue = "71503")]
773-
pub fn retain<F>(&mut self, f: F)
784+
pub fn retain<F>(&mut self, mut f: F)
774785
where
775786
F: FnMut(&T) -> bool,
776787
{
777-
self.data.retain(f);
778-
self.rebuild();
788+
let mut first_removed = self.len();
789+
let mut i = 0;
790+
self.data.retain(|e| {
791+
let keep = f(e);
792+
if !keep && i < first_removed {
793+
first_removed = i;
794+
}
795+
i += 1;
796+
keep
797+
});
798+
// data[0..first_removed] is untouched, so we only need to rebuild the tail:
799+
self.rebuild_tail(first_removed);
779800
}
780801
}
781802

0 commit comments

Comments
 (0)