Skip to content

Document BinaryHeap unsafe functions #81706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 117 additions & 49 deletions library/alloc/src/collections/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ impl<T: Ord + fmt::Debug> fmt::Debug for PeekMut<'_, T> {
impl<T: Ord> Drop for PeekMut<'_, T> {
fn drop(&mut self) {
if self.sift {
self.heap.sift_down(0);
// SAFETY: PeekMut is only instantiated for non-empty heaps.
unsafe { self.heap.sift_down(0) };
}
}
}
Expand Down Expand Up @@ -431,7 +432,8 @@ impl<T: Ord> BinaryHeap<T> {
self.data.pop().map(|mut item| {
if !self.is_empty() {
swap(&mut item, &mut self.data[0]);
self.sift_down_to_bottom(0);
// SAFETY: !self.is_empty() means that self.len() > 0
unsafe { self.sift_down_to_bottom(0) };
}
item
})
Expand Down Expand Up @@ -473,7 +475,9 @@ impl<T: Ord> BinaryHeap<T> {
pub fn push(&mut self, item: T) {
let old_len = self.len();
self.data.push(item);
self.sift_up(0, old_len);
// SAFETY: Since we pushed a new item it means that
// old_len = self.len() - 1 < self.len()
unsafe { self.sift_up(0, old_len) };
}

/// Consumes the `BinaryHeap` and returns a vector in sorted
Expand Down Expand Up @@ -506,7 +510,10 @@ impl<T: Ord> BinaryHeap<T> {
let ptr = self.data.as_mut_ptr();
ptr::swap(ptr, ptr.add(end));
}
self.sift_down_range(0, end);
// SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
// 0 < 1 <= end <= self.len() - 1 < self.len()
// Which means 0 < end and end < self.len().
unsafe { self.sift_down_range(0, end) };
}
self.into_vec()
}
Expand All @@ -519,78 +526,139 @@ impl<T: Ord> BinaryHeap<T> {
// the hole is filled back at the end of its scope, even on panic.
// Using a hole reduces the constant factor compared to using swaps,
// which involves twice as many moves.
fn sift_up(&mut self, start: usize, pos: usize) -> usize {
unsafe {
// Take out the value at `pos` and create a hole.
let mut hole = Hole::new(&mut self.data, pos);

while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;
if hole.element() <= hole.get(parent) {
break;
}
hole.move_to(parent);

/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
// Take out the value at `pos` and create a hole.
// SAFETY: The caller guarantees that pos < self.len()
let mut hole = unsafe { Hole::new(&mut self.data, pos) };

while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;

// SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
// and so hole.pos() - 1 can't underflow.
// This guarantees that parent < hole.pos() so
// it's a valid index and also != hole.pos().
if hole.element() <= unsafe { hole.get(parent) } {
break;
}
hole.pos()

// SAFETY: Same as above
unsafe { hole.move_to(parent) };
}

hole.pos()
}

/// Take an element at `pos` and move it down the heap,
/// while its children are larger.
fn sift_down_range(&mut self, pos: usize, end: usize) {
unsafe {
let mut hole = Hole::new(&mut self.data, pos);
let mut child = 2 * pos + 1;
while child < end - 1 {
// compare with the greater of the two children
child += (hole.get(child) <= hole.get(child + 1)) as usize;
// if we are already in order, stop.
if hole.element() >= hole.get(child) {
return;
}
hole.move_to(child);
child = 2 * hole.pos() + 1;
}
if child == end - 1 && hole.element() < hole.get(child) {
hole.move_to(child);
///
/// # Safety
///
/// The caller must guarantee that `pos < end <= self.len()`.
unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
// SAFETY: The caller guarantees that pos < end <= self.len().
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
let mut child = 2 * hole.pos() + 1;

// Loop invariant: child == 2 * hole.pos() + 1.
while child < end - 1 {
// compare with the greater of the two children
// SAFETY: child < end - 1 < self.len() and
// child + 1 < end <= self.len(), so they're valid indexes.
// child == 2 * hole.pos() + 1 != hole.pos() and
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
// if T is a ZST
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;

// if we are already in order, stop.
// SAFETY: child is now either the old child or the old child+1
// We already proven that both are < self.len() and != hole.pos()
if hole.element() >= unsafe { hole.get(child) } {
return;
}

// SAFETY: same as above.
unsafe { hole.move_to(child) };
child = 2 * hole.pos() + 1;
}

// SAFETY: && short circuit, which means that in the
// second condition it's already true that child == end - 1 < self.len().
if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
// SAFETY: child is already proven to be a valid index and
// child == 2 * hole.pos() + 1 != hole.pos().
unsafe { hole.move_to(child) };
}
}

fn sift_down(&mut self, pos: usize) {
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_down(&mut self, pos: usize) {
let len = self.len();
self.sift_down_range(pos, len);
// SAFETY: pos < len is guaranteed by the caller and
// obviously len = self.len() <= self.len().
unsafe { self.sift_down_range(pos, len) };
}

/// Take an element at `pos` and move it all the way down the heap,
/// then sift it up to its position.
///
/// Note: This is faster when the element is known to be large / should
/// be closer to the bottom.
fn sift_down_to_bottom(&mut self, mut pos: usize) {
///
/// # Safety
///
/// The caller must guarantee that `pos < self.len()`.
unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
let end = self.len();
let start = pos;
unsafe {
let mut hole = Hole::new(&mut self.data, pos);
let mut child = 2 * pos + 1;
while child < end - 1 {
child += (hole.get(child) <= hole.get(child + 1)) as usize;
hole.move_to(child);
child = 2 * hole.pos() + 1;
}
if child == end - 1 {
hole.move_to(child);
}
pos = hole.pos;

// SAFETY: The caller guarantees that pos < self.len().
let mut hole = unsafe { Hole::new(&mut self.data, pos) };
let mut child = 2 * hole.pos() + 1;

// Loop invariant: child == 2 * hole.pos() + 1.
while child < end - 1 {
// SAFETY: child < end - 1 < self.len() and
// child + 1 < end <= self.len(), so they're valid indexes.
// child == 2 * hole.pos() + 1 != hole.pos() and
// child + 1 == 2 * hole.pos() + 2 != hole.pos().
// FIXME: 2 * hole.pos() + 1 or 2 * hole.pos() + 2 could overflow
// if T is a ZST
child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;

// SAFETY: Same as above
unsafe { hole.move_to(child) };
child = 2 * hole.pos() + 1;
}
self.sift_up(start, pos);

if child == end - 1 {
// SAFETY: child == end - 1 < self.len(), so it's a valid index
// and child == 2 * hole.pos() + 1 != hole.pos().
unsafe { hole.move_to(child) };
}
pos = hole.pos();
drop(hole);

// SAFETY: pos is the position in the hole and was already proven
// to be a valid index.
unsafe { self.sift_up(start, pos) };
}

fn rebuild(&mut self) {
let mut n = self.len() / 2;
while n > 0 {
n -= 1;
self.sift_down(n);
// SAFETY: n starts from self.len() / 2 and goes down to 0.
// The only case when !(n < self.len()) is if
// self.len() == 0, but it's ruled out by the loop condition.
unsafe { self.sift_down(n) };
}
}

Expand Down