Skip to content

Commit c41a5cc

Browse files
author
Ulrik Sverdrup
committed
collections: Make BinaryHeap panic safe in sift_up / sift_down
Use a struct called Hole that keeps track of an invalid location in the vector and fills the hole on drop. I include a run-pass test that the current BinaryHeap fails, and the new one passes. Fixes #25842
1 parent 541fe5f commit c41a5cc

File tree

2 files changed

+201
-27
lines changed

2 files changed

+201
-27
lines changed

src/libcollections/binary_heap.rs

+93-27
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@
153153
use core::prelude::*;
154154

155155
use core::iter::{FromIterator};
156-
use core::mem::{zeroed, replace, swap};
156+
use core::mem::swap;
157157
use core::ptr;
158158

159159
use slice;
@@ -484,44 +484,43 @@ impl<T: Ord> BinaryHeap<T> {
484484

485485
// The implementations of sift_up and sift_down use unsafe blocks in
486486
// order to move an element out of the vector (leaving behind a
487-
// zeroed element), shift along the others and move it back into the
488-
// vector over the junk element. This reduces the constant factor
489-
// compared to using swaps, which involves twice as many moves.
490-
fn sift_up(&mut self, start: usize, mut pos: usize) {
487+
// hole), shift along the others and move the removed element back into the
488+
// vector at the final location of the hole.
489+
// The `Hole` type is used to represent this, and make sure
490+
// the hole is filled back at the end of its scope, even on panic.
491+
// Using a hole reduces the constant factor compared to using swaps,
492+
// which involves twice as many moves.
493+
fn sift_up(&mut self, start: usize, pos: usize) {
491494
unsafe {
492-
let new = replace(&mut self.data[pos], zeroed());
495+
// Take out the value at `pos` and create a hole.
496+
let mut hole = Hole::new(&mut self.data, pos);
493497

494-
while pos > start {
495-
let parent = (pos - 1) >> 1;
496-
497-
if new <= self.data[parent] { break; }
498-
499-
let x = replace(&mut self.data[parent], zeroed());
500-
ptr::write(&mut self.data[pos], x);
501-
pos = parent;
498+
while hole.pos() > start {
499+
let parent = (hole.pos() - 1) / 2;
500+
if hole.removed() <= hole.get(parent) { break }
501+
hole.move_to(parent);
502502
}
503-
ptr::write(&mut self.data[pos], new);
504503
}
505504
}
506505

507506
fn sift_down_range(&mut self, mut pos: usize, end: usize) {
508507
unsafe {
509508
let start = pos;
510-
let new = replace(&mut self.data[pos], zeroed());
511-
512-
let mut child = 2 * pos + 1;
513-
while child < end {
514-
let right = child + 1;
515-
if right < end && !(self.data[child] > self.data[right]) {
516-
child = right;
509+
{
510+
let mut hole = Hole::new(&mut self.data, pos);
511+
let mut child = 2 * pos + 1;
512+
while child < end {
513+
let right = child + 1;
514+
if right < end && !(hole.get(child) > hole.get(right)) {
515+
child = right;
516+
}
517+
hole.move_to(child);
518+
child = 2 * hole.pos() + 1;
517519
}
518-
let x = replace(&mut self.data[child], zeroed());
519-
ptr::write(&mut self.data[pos], x);
520-
pos = child;
521-
child = 2 * pos + 1;
520+
521+
pos = hole.pos;
522522
}
523523

524-
ptr::write(&mut self.data[pos], new);
525524
self.sift_up(start, pos);
526525
}
527526
}
@@ -554,6 +553,73 @@ impl<T: Ord> BinaryHeap<T> {
554553
pub fn clear(&mut self) { self.drain(); }
555554
}
556555

556+
/// `Hole` represents a hole in a slice, i.e. an index without a valid value
557+
/// (because it was moved from or duplicated).
558+
/// In drop, `Hole` will restore the slice by filling the hole
559+
/// position with the value that was originally removed.
560+
struct Hole<'a, T: 'a> {
561+
data: &'a mut [T],
562+
/// `elt` is always `Some` from new until drop.
563+
elt: Option<T>,
564+
pos: usize,
565+
}
566+
567+
impl<'a, T> Hole<'a, T> {
568+
/// Create a new Hole at index `pos`.
569+
fn new(data: &'a mut [T], pos: usize) -> Self {
570+
unsafe {
571+
let elt = ptr::read(&data[pos]);
572+
Hole {
573+
data: data,
574+
elt: Some(elt),
575+
pos: pos,
576+
}
577+
}
578+
}
579+
580+
#[inline(always)]
581+
fn pos(&self) -> usize { self.pos }
582+
583+
/// Return a reference to the element removed
584+
#[inline(always)]
585+
fn removed(&self) -> &T {
586+
self.elt.as_ref().unwrap()
587+
}
588+
589+
/// Return a reference to the element at `index`.
590+
///
591+
/// Panics if the index is out of bounds.
592+
///
593+
/// Unsafe because `index` must not equal `pos`.
594+
#[inline(always)]
595+
unsafe fn get(&self, index: usize) -> &T {
596+
debug_assert!(index != self.pos);
597+
&self.data[index]
598+
}
599+
600+
/// Move hole to new location
601+
///
602+
/// Unsafe because `index` must not equal `pos`.
603+
#[inline(always)]
604+
unsafe fn move_to(&mut self, index: usize) {
605+
debug_assert!(index != self.pos);
606+
let index_ptr: *const _ = &self.data[index];
607+
let hole_ptr = &mut self.data[self.pos];
608+
ptr::copy_nonoverlapping(index_ptr, hole_ptr, 1);
609+
self.pos = index;
610+
}
611+
}
612+
613+
impl<'a, T> Drop for Hole<'a, T> {
614+
fn drop(&mut self) {
615+
// fill the hole again
616+
unsafe {
617+
let pos = self.pos;
618+
ptr::write(&mut self.data[pos], self.elt.take().unwrap());
619+
}
620+
}
621+
}
622+
557623
/// `BinaryHeap` iterator.
558624
#[stable(feature = "rust1", since = "1.0.0")]
559625
pub struct Iter <'a, T: 'a> {
+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// http://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
#![feature(std_misc, collections, catch_panic, rand)]
12+
13+
use std::__rand::{thread_rng, Rng};
14+
use std::thread;
15+
16+
use std::collections::BinaryHeap;
17+
use std::cmp;
18+
use std::sync::Arc;
19+
use std::sync::Mutex;
20+
use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};
21+
22+
static DROP_COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;
23+
24+
// old binaryheap failed this test
25+
//
26+
// Integrity means that all elements are present after a comparison panics,
27+
// even if the order may not be correct.
28+
//
29+
// Destructors must be called exactly once per element.
30+
fn test_integrity() {
31+
#[derive(Eq, PartialEq, Ord, Clone, Debug)]
32+
struct PanicOrd<T>(T, bool);
33+
34+
impl<T> Drop for PanicOrd<T> {
35+
fn drop(&mut self) {
36+
// update global drop count
37+
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
38+
}
39+
}
40+
41+
impl<T: PartialOrd> PartialOrd for PanicOrd<T> {
42+
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
43+
if self.1 || other.1 {
44+
panic!("Panicking comparison");
45+
}
46+
self.0.partial_cmp(&other.0)
47+
}
48+
}
49+
let mut rng = thread_rng();
50+
const DATASZ: usize = 32;
51+
const NTEST: usize = 10;
52+
53+
// don't use 0 in the data -- we want to catch the zeroed-out case.
54+
let data = (1..DATASZ + 1).collect::<Vec<_>>();
55+
56+
// since it's a randomized test, run it several times.
57+
for _ in 0..NTEST {
58+
for i in 1..DATASZ + 1 {
59+
DROP_COUNTER.store(0, Ordering::SeqCst);
60+
61+
let mut panic_ords: Vec<_> = data.iter()
62+
.filter(|&&x| x != i)
63+
.map(|&x| PanicOrd(x, false))
64+
.collect();
65+
let panic_item = PanicOrd(i, true);
66+
67+
// heapify the sane items
68+
rng.shuffle(&mut panic_ords);
69+
let heap = Arc::new(Mutex::new(BinaryHeap::from_vec(panic_ords)));
70+
let inner_data;
71+
72+
{
73+
let heap_ref = heap.clone();
74+
75+
76+
// push the panicking item to the heap and catch the panic
77+
let thread_result = thread::catch_panic(move || {
78+
heap.lock().unwrap().push(panic_item);
79+
});
80+
assert!(thread_result.is_err());
81+
82+
// Assert no elements were dropped
83+
let drops = DROP_COUNTER.load(Ordering::SeqCst);
84+
//assert!(drops == 0, "Must not drop items. drops={}", drops);
85+
86+
{
87+
// now fetch the binary heap's data vector
88+
let mutex_guard = match heap_ref.lock() {
89+
Ok(x) => x,
90+
Err(poison) => poison.into_inner(),
91+
};
92+
inner_data = mutex_guard.clone().into_vec();
93+
}
94+
}
95+
let drops = DROP_COUNTER.load(Ordering::SeqCst);
96+
assert_eq!(drops, DATASZ);
97+
98+
let mut data_sorted = inner_data.into_iter().map(|p| p.0).collect::<Vec<_>>();
99+
data_sorted.sort();
100+
assert_eq!(data_sorted, data);
101+
}
102+
}
103+
}
104+
105+
fn main() {
106+
test_integrity();
107+
}
108+

0 commit comments

Comments
 (0)