diff --git a/src/liballoc/tests/vec.rs b/src/liballoc/tests/vec.rs index 5ddac673c9ff1..07ee37412680d 100644 --- a/src/liballoc/tests/vec.rs +++ b/src/liballoc/tests/vec.rs @@ -945,6 +945,115 @@ fn drain_filter_complex() { } } +#[test] +#[cfg(not(miri))] // Miri does not support catching panics +fn drain_filter_consumed_panic() { + use std::rc::Rc; + use std::sync::Mutex; + + struct Check { + index: usize, + drop_counts: Rc>>, + }; + + impl Drop for Check { + fn drop(&mut self) { + self.drop_counts.lock().unwrap()[self.index] += 1; + println!("drop: {}", self.index); + } + } + + let check_count = 10; + let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count])); + let mut data: Vec = (0..check_count) + .map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) }) + .collect(); + + let _ = std::panic::catch_unwind(move || { + let filter = |c: &mut Check| { + if c.index == 2 { + panic!("panic at index: {}", c.index); + } + // Verify that if the filter could panic again on another element + // that it would not cause a double panic and all elements of the + // vec would still be dropped exactly once. + if c.index == 4 { + panic!("panic at index: {}", c.index); + } + c.index < 6 + }; + let drain = data.drain_filter(filter); + + // NOTE: The DrainFilter is explictly consumed + drain.for_each(drop); + }); + + let drop_counts = drop_counts.lock().unwrap(); + assert_eq!(check_count, drop_counts.len()); + + for (index, count) in drop_counts.iter().cloned().enumerate() { + assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count); + } +} + +#[test] +#[cfg(not(miri))] // Miri does not support catching panics +fn drain_filter_unconsumed_panic() { + use std::rc::Rc; + use std::sync::Mutex; + + struct Check { + index: usize, + drop_counts: Rc>>, + }; + + impl Drop for Check { + fn drop(&mut self) { + self.drop_counts.lock().unwrap()[self.index] += 1; + println!("drop: {}", self.index); + } + } + + let check_count = 10; + let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count])); + let mut data: Vec = (0..check_count) + .map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) }) + .collect(); + + let _ = std::panic::catch_unwind(move || { + let filter = |c: &mut Check| { + if c.index == 2 { + panic!("panic at index: {}", c.index); + } + // Verify that if the filter could panic again on another element + // that it would not cause a double panic and all elements of the + // vec would still be dropped exactly once. + if c.index == 4 { + panic!("panic at index: {}", c.index); + } + c.index < 6 + }; + let _drain = data.drain_filter(filter); + + // NOTE: The DrainFilter is dropped without being consumed + }); + + let drop_counts = drop_counts.lock().unwrap(); + assert_eq!(check_count, drop_counts.len()); + + for (index, count) in drop_counts.iter().cloned().enumerate() { + assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count); + } +} + +#[test] +fn drain_filter_unconsumed() { + let mut vec = vec![1, 2, 3, 4]; + let drain = vec.drain_filter(|&mut x| x % 2 != 0); + drop(drain); + assert_eq!(vec, [2, 4]); +} + #[test] fn test_reserve_exact() { // This is all the same as test_reserve diff --git a/src/liballoc/vec.rs b/src/liballoc/vec.rs index 5cb91395b7bf7..8939248adab0d 100644 --- a/src/liballoc/vec.rs +++ b/src/liballoc/vec.rs @@ -2120,6 +2120,7 @@ impl Vec { del: 0, old_len, pred: filter, + panic_flag: false, } } } @@ -2747,10 +2748,20 @@ pub struct DrainFilter<'a, T, F> where F: FnMut(&mut T) -> bool, { vec: &'a mut Vec, + /// The index of the item that will be inspected by the next call to `next`. idx: usize, + /// The number of items that have been drained (removed) thus far. del: usize, + /// The original length of `vec` prior to draining. old_len: usize, + /// The filter test predicate. pred: F, + /// A flag that indicates a panic has occured in the filter test prodicate. + /// This is used as a hint in the drop implmentation to prevent consumption + /// of the remainder of the `DrainFilter`. Any unprocessed items will be + /// backshifted in the `vec`, but no further items will be dropped or + /// tested by the filter predicate. + panic_flag: bool, } #[unstable(feature = "drain_filter", reason = "recently added", issue = "43244")] @@ -2761,20 +2772,23 @@ impl Iterator for DrainFilter<'_, T, F> fn next(&mut self) -> Option { unsafe { - while self.idx != self.old_len { + while self.idx < self.old_len { let i = self.idx; - self.idx += 1; let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len); - if (self.pred)(&mut v[i]) { + self.panic_flag = true; + let drained = (self.pred)(&mut v[i]); + self.panic_flag = false; + // Update the index *after* the predicate is called. If the index + // is updated prior and the predicate panics, the element at this + // index would be leaked. + self.idx += 1; + if drained { self.del += 1; return Some(ptr::read(&v[i])); } else if self.del > 0 { let del = self.del; let src: *const T = &v[i]; let dst: *mut T = &mut v[i - del]; - // This is safe because self.vec has length 0 - // thus its elements will not have Drop::drop - // called on them in the event of a panic. ptr::copy_nonoverlapping(src, dst, 1); } } @@ -2792,9 +2806,46 @@ impl Drop for DrainFilter<'_, T, F> where F: FnMut(&mut T) -> bool, { fn drop(&mut self) { - self.for_each(drop); - unsafe { - self.vec.set_len(self.old_len - self.del); + struct BackshiftOnDrop<'a, 'b, T, F> + where + F: FnMut(&mut T) -> bool, + { + drain: &'b mut DrainFilter<'a, T, F>, + } + + impl<'a, 'b, T, F> Drop for BackshiftOnDrop<'a, 'b, T, F> + where + F: FnMut(&mut T) -> bool + { + fn drop(&mut self) { + unsafe { + if self.drain.idx < self.drain.old_len && self.drain.del > 0 { + // This is a pretty messed up state, and there isn't really an + // obviously right thing to do. We don't want to keep trying + // to execute `pred`, so we just backshift all the unprocessed + // elements and tell the vec that they still exist. The backshift + // is required to prevent a double-drop of the last successfully + // drained item prior to a panic in the predicate. + let ptr = self.drain.vec.as_mut_ptr(); + let src = ptr.add(self.drain.idx); + let dst = src.sub(self.drain.del); + let tail_len = self.drain.old_len - self.drain.idx; + src.copy_to(dst, tail_len); + } + self.drain.vec.set_len(self.drain.old_len - self.drain.del); + } + } + } + + let backshift = BackshiftOnDrop { + drain: self + }; + + // Attempt to consume any remaining elements if the filter predicate + // has not yet panicked. We'll backshift any remaining elements + // whether we've already panicked or if the consumption here panics. + if !backshift.drain.panic_flag { + backshift.drain.for_each(drop); } } }