Skip to content

Commit f4a8cf0

Browse files
committed
Auto merge of rust-lang#78681 - m-ou-se:binary-heap-retain, r=Amanieu
Improve rebuilding behaviour of BinaryHeap::retain. This changes `BinaryHeap::retain` such that it doesn't always fully rebuild the heap, but only rebuilds the parts for which that's necessary. This makes use of the fact that retain gives out `&T`s and not `&mut T`s. Retaining every element or removing only elements at the end results in no rebuilding at all. Retaining most elements results in only reordering the elements that got moved (those after the first removed element), using the same logic as was already used for `append`. cc `@KodrAus` `@sfackler` - We briefly discussed this possibility in the meeting last week while we talked about stabilization of this function (rust-lang#71503).
2 parents 7f4afdf + f5d72ab commit f4a8cf0

File tree

2 files changed

+69
-35
lines changed

2 files changed

+69
-35
lines changed

library/alloc/src/collections/binary_heap.rs

+53-32
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,43 @@ impl<T: Ord> BinaryHeap<T> {
652652
unsafe { self.sift_up(start, pos) };
653653
}
654654

655+
/// Rebuild assuming data[0..start] is still a proper heap.
656+
fn rebuild_tail(&mut self, start: usize) {
657+
if start == self.len() {
658+
return;
659+
}
660+
661+
let tail_len = self.len() - start;
662+
663+
#[inline(always)]
664+
fn log2_fast(x: usize) -> usize {
665+
(usize::BITS - x.leading_zeros() - 1) as usize
666+
}
667+
668+
// `rebuild` takes O(self.len()) operations
669+
// and about 2 * self.len() comparisons in the worst case
670+
// while repeating `sift_up` takes O(tail_len * log(start)) operations
671+
// and about 1 * tail_len * log_2(start) comparisons in the worst case,
672+
// assuming start >= tail_len. For larger heaps, the crossover point
673+
// no longer follows this reasoning and was determined empirically.
674+
let better_to_rebuild = if start < tail_len {
675+
true
676+
} else if self.len() <= 2048 {
677+
2 * self.len() < tail_len * log2_fast(start)
678+
} else {
679+
2 * self.len() < tail_len * 11
680+
};
681+
682+
if better_to_rebuild {
683+
self.rebuild();
684+
} else {
685+
for i in start..self.len() {
686+
// SAFETY: The index `i` is always less than self.len().
687+
unsafe { self.sift_up(0, i) };
688+
}
689+
}
690+
}
691+
655692
fn rebuild(&mut self) {
656693
let mut n = self.len() / 2;
657694
while n > 0 {
@@ -689,37 +726,11 @@ impl<T: Ord> BinaryHeap<T> {
689726
swap(self, other);
690727
}
691728

692-
if other.is_empty() {
693-
return;
694-
}
695-
696-
#[inline(always)]
697-
fn log2_fast(x: usize) -> usize {
698-
(usize::BITS - x.leading_zeros() - 1) as usize
699-
}
729+
let start = self.data.len();
700730

701-
// `rebuild` takes O(len1 + len2) operations
702-
// and about 2 * (len1 + len2) comparisons in the worst case
703-
// while `extend` takes O(len2 * log(len1)) operations
704-
// and about 1 * len2 * log_2(len1) comparisons in the worst case,
705-
// assuming len1 >= len2. For larger heaps, the crossover point
706-
// no longer follows this reasoning and was determined empirically.
707-
#[inline]
708-
fn better_to_rebuild(len1: usize, len2: usize) -> bool {
709-
let tot_len = len1 + len2;
710-
if tot_len <= 2048 {
711-
2 * tot_len < len2 * log2_fast(len1)
712-
} else {
713-
2 * tot_len < len2 * 11
714-
}
715-
}
731+
self.data.append(&mut other.data);
716732

717-
if better_to_rebuild(self.len(), other.len()) {
718-
self.data.append(&mut other.data);
719-
self.rebuild();
720-
} else {
721-
self.extend(other.drain());
722-
}
733+
self.rebuild_tail(start);
723734
}
724735

725736
/// Returns an iterator which retrieves elements in heap order.
@@ -770,12 +781,22 @@ impl<T: Ord> BinaryHeap<T> {
770781
/// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
771782
/// ```
772783
#[unstable(feature = "binary_heap_retain", issue = "71503")]
773-
pub fn retain<F>(&mut self, f: F)
784+
pub fn retain<F>(&mut self, mut f: F)
774785
where
775786
F: FnMut(&T) -> bool,
776787
{
777-
self.data.retain(f);
778-
self.rebuild();
788+
let mut first_removed = self.len();
789+
let mut i = 0;
790+
self.data.retain(|e| {
791+
let keep = f(e);
792+
if !keep && i < first_removed {
793+
first_removed = i;
794+
}
795+
i += 1;
796+
keep
797+
});
798+
// data[0..first_removed] is untouched, so we only need to rebuild the tail:
799+
self.rebuild_tail(first_removed);
779800
}
780801
}
781802

library/alloc/tests/binary_heap.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,23 @@ fn assert_covariance() {
386386

387387
#[test]
fn test_retain() {
    let mut heap = BinaryHeap::from(vec![100, 10, 50, 1, 2, 20, 30]);

    // Removing 2 leaves a hole that 20 sifts up into (10's old slot);
    // the prefix before the removal keeps its exact layout.
    heap.retain(|&x| x != 2);
    assert_eq!(heap.clone().into_vec(), [100, 20, 50, 1, 10, 30]);

    // Keeping every element must not reorder the backing vector at all.
    heap.retain(|_| true);
    assert_eq!(heap.clone().into_vec(), [100, 20, 50, 1, 10, 30]);

    // Rejecting the root means nothing of the prefix survives untouched,
    // so the whole heap gets rebuilt.
    heap.retain(|&x| x < 50);
    assert_eq!(heap.clone().into_vec(), [30, 20, 10, 1]);

    // Rejecting everything leaves an empty heap.
    heap.retain(|_| false);
    assert!(heap.is_empty());
}
394407

395408
// old binaryheap failed this test

0 commit comments

Comments
 (0)