Skip to content

Commit 5f44f5b

Browse files
author
Ulrik Sverdrup
committed
collections: Make BinaryHeap panic safe in sift_up / sift_down
Use a struct called Hole that keeps track of an invalid location in the vector and fills the hole on drop. I include a run-pass test that the current BinaryHeap fails, and the new one passes. Fixes #25842
1 parent 541fe5f commit 5f44f5b

File tree

2 files changed

+161
-27
lines changed

2 files changed

+161
-27
lines changed

src/libcollections/binary_heap.rs

+90-27
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@
153153
use core::prelude::*;
154154

155155
use core::iter::{FromIterator};
156-
use core::mem::{zeroed, replace, swap};
156+
use core::mem::swap;
157157
use core::ptr;
158158

159159
use slice;
@@ -484,44 +484,43 @@ impl<T: Ord> BinaryHeap<T> {
484484

485485
// The implementations of sift_up and sift_down use unsafe blocks in
486486
// order to move an element out of the vector (leaving behind a
487-
// zeroed element), shift along the others and move it back into the
488-
// vector over the junk element. This reduces the constant factor
489-
// compared to using swaps, which involves twice as many moves.
490-
fn sift_up(&mut self, start: usize, mut pos: usize) {
487+
// hole), shift along the others and move the removed element back into the
488+
// vector at the final location of the hole.
489+
// The `Hole` type is used to represent this, and make sure
490+
// the hole is filled back at the end of its scope, even on panic.
491+
// Using a hole reduces the constant factor compared to using swaps,
492+
// which involves twice as many moves.
493+
fn sift_up(&mut self, start: usize, pos: usize) {
491494
unsafe {
492-
let new = replace(&mut self.data[pos], zeroed());
495+
// Take out the value at `pos` and create a hole.
496+
let mut hole = Hole::new(&mut self.data, pos);
493497

494-
while pos > start {
495-
let parent = (pos - 1) >> 1;
496-
497-
if new <= self.data[parent] { break; }
498-
499-
let x = replace(&mut self.data[parent], zeroed());
500-
ptr::write(&mut self.data[pos], x);
501-
pos = parent;
498+
while hole.pos() > start {
499+
let parent = (hole.pos() - 1) >> 1;
500+
if hole.removed() <= hole.get(parent) { break }
501+
hole.move_to(parent);
502502
}
503-
ptr::write(&mut self.data[pos], new);
504503
}
505504
}
506505

507506
fn sift_down_range(&mut self, mut pos: usize, end: usize) {
508507
unsafe {
509508
let start = pos;
510-
let new = replace(&mut self.data[pos], zeroed());
511-
512-
let mut child = 2 * pos + 1;
513-
while child < end {
514-
let right = child + 1;
515-
if right < end && !(self.data[child] > self.data[right]) {
516-
child = right;
509+
{
510+
let mut hole = Hole::new(&mut self.data, pos);
511+
let mut child = 2 * pos + 1;
512+
while child < end {
513+
let right = child + 1;
514+
if right < end && !(hole.get(child) > hole.get(right)) {
515+
child = right;
516+
}
517+
hole.move_to(child);
518+
child = 2 * hole.pos() + 1;
517519
}
518-
let x = replace(&mut self.data[child], zeroed());
519-
ptr::write(&mut self.data[pos], x);
520-
pos = child;
521-
child = 2 * pos + 1;
520+
521+
pos = hole.pos;
522522
}
523523

524-
ptr::write(&mut self.data[pos], new);
525524
self.sift_up(start, pos);
526525
}
527526
}
@@ -554,6 +553,70 @@ impl<T: Ord> BinaryHeap<T> {
554553
pub fn clear(&mut self) { self.drain(); }
555554
}
556555

556+
/// Hole represents a hole in a slice i.e. an index without valid value
557+
/// (because it was moved from or duplicated).
558+
/// In drop, `Hole` will restore the slice by filling the hole
559+
/// position with the value that was originally removed.
560+
struct Hole<'a, T: 'a> {
561+
data: &'a mut [T],
562+
elt: Option<T>,
563+
pos: usize,
564+
}
565+
566+
impl<'a, T> Hole<'a, T> {
567+
/// Create a new Hole at index `pos`.
568+
pub fn new(data: &'a mut [T], pos: usize) -> Self {
569+
unsafe {
570+
let elt = ptr::read(&data[pos]);
571+
Hole {
572+
data: data,
573+
elt: Some(elt),
574+
pos: pos,
575+
}
576+
}
577+
}
578+
579+
#[inline(always)]
580+
pub fn pos(&self) -> usize { self.pos }
581+
582+
/// Return a reference to the element removed
583+
#[inline(always)]
584+
pub fn removed(&self) -> &T {
585+
self.elt.as_ref().unwrap()
586+
}
587+
588+
/// Return a reference to the element at `index`.
589+
///
590+
/// Panics if the index is out of bounds.
591+
///
592+
/// Unsafe because index must not equal pos.
593+
#[inline(always)]
594+
pub unsafe fn get(&self, index: usize) -> &T {
595+
debug_assert!(index != self.pos);
596+
&self.data[index]
597+
}
598+
599+
/// Move hole to new location
600+
#[inline(always)]
601+
pub unsafe fn move_to(&mut self, index: usize) {
602+
debug_assert!(index != self.pos);
603+
let old_pos = self.pos;
604+
let x = ptr::read(&mut self.data[index]);
605+
ptr::write(&mut self.data[old_pos], x);
606+
self.pos = index;
607+
}
608+
}
609+
610+
impl<'a, T> Drop for Hole<'a, T> {
611+
fn drop(&mut self) {
612+
// fill the hole again
613+
unsafe {
614+
let pos = self.pos;
615+
ptr::write(&mut self.data[pos], self.elt.take().unwrap());
616+
}
617+
}
618+
}
619+
557620
/// `BinaryHeap` iterator.
558621
#[stable(feature = "rust1", since = "1.0.0")]
559622
pub struct Iter <'a, T: 'a> {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
2+
#![feature(std_misc, collections, catch_panic, rand)]
3+
4+
use std::__rand::{thread_rng, Rng};
5+
use std::thread;
6+
7+
use std::collections::BinaryHeap;
8+
use std::cmp::Ordering;
9+
use std::sync::Arc;
10+
use std::sync::Mutex;
11+
12+
// old binaryheap failed this test
13+
//
14+
// Integrity means that all elements are present after a comparison panics,
15+
// even if the order may not be correct.
16+
fn test_integrity() {
17+
#[derive(Eq, PartialEq, Ord, Clone, Debug)]
18+
struct PanicOrd<T>(T, bool);
19+
20+
impl<T: PartialOrd> PartialOrd for PanicOrd<T> {
21+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
22+
if self.1 || other.1 {
23+
panic!("Panicking comparison");
24+
}
25+
self.0.partial_cmp(&other.0)
26+
}
27+
}
28+
let mut rng = thread_rng();
29+
const DATASZ: usize = 32;
30+
const NTEST: usize = 10;
31+
32+
// don't use 0 in the data -- we want to catch the zeroed-out case.
33+
let data = (1..DATASZ + 1).collect::<Vec<_>>();
34+
35+
// since it's a fuzzy test, run several tries.
36+
for _ in 0..NTEST {
37+
for i in 1..DATASZ + 1 {
38+
let mut panic_ords: Vec<_> = data.iter()
39+
.filter(|&&x| x != i)
40+
.map(|&x| PanicOrd(x, false))
41+
.collect();
42+
let panic_item = PanicOrd(i, true);
43+
44+
// heapify the sane items
45+
rng.shuffle(&mut panic_ords);
46+
let heap = Arc::new(Mutex::new(BinaryHeap::from_vec(panic_ords)));
47+
let heap_ref = heap.clone();
48+
49+
// push the panicking item to the heap and catch the panic
50+
let thread_result = thread::catch_panic(move || {
51+
heap.lock().unwrap().push(panic_item);
52+
});
53+
assert!(thread_result.is_err());
54+
55+
// now fetch the binary heap again
56+
let mutex_guard = match heap_ref.lock() {
57+
Ok(x) => x,
58+
Err(poison) => poison.into_inner(),
59+
};
60+
let inner_data = mutex_guard.clone().into_vec();
61+
let mut data_sorted = inner_data.into_iter().map(|p| p.0).collect::<Vec<_>>();
62+
data_sorted.sort();
63+
assert_eq!(data_sorted, data);
64+
}
65+
}
66+
}
67+
68+
fn main() {
69+
test_integrity();
70+
}
71+

0 commit comments

Comments
 (0)