Skip to content

Commit fab9b1d

Browse files
authored
Rollup merge of #74677 - ssomers:btree_cleanup_2, r=Amanieu
Remove needless unsafety from BTreeMap::drain_filter Remove one piece of unsafe code in the iteration over the iterator returned by BTreeMap::drain_filter. - Changes an explicitly unspecified part of the API: when the user-supplied predicate (or some of BTreeMap's code) panicked, and the caller tries to use the iterator again, we no longer offer the same key/value pair to the predicate again but pretend the iterator has finished. Note that Miri does not find UB in the test case added here with the unsafe code (or without). - Makes the code a little easier on the eyes. - Makes the code a little harder on the CPU: ``` benchcmp c0 c2 --threshold 3 name c0 ns/iter c2 ns/iter diff ns/iter diff % speedup btree::set::clone_100_and_drain_all 2,794 2,900 106 3.79% x 0.96 btree::set::clone_100_and_drain_half 2,604 2,964 360 13.82% x 0.88 btree::set::clone_10k_and_drain_half 287,770 322,755 34,985 12.16% x 0.89 ``` r? @Amanieu
2 parents 7f2bb29 + facc46f commit fab9b1d

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

src/liballoc/collections/btree/map.rs

+1-8
Original file line numberDiff line numberDiff line change
@@ -1672,19 +1672,12 @@ impl<'a, K: 'a, V: 'a> DrainFilterInner<'a, K, V> {
16721672
edge.reborrow().next_kv().ok().map(|kv| kv.into_kv())
16731673
}
16741674

1675-
unsafe fn next_kv(
1676-
&mut self,
1677-
) -> Option<Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, marker::KV>> {
1678-
let edge = self.cur_leaf_edge.as_ref()?;
1679-
unsafe { ptr::read(edge).next_kv().ok() }
1680-
}
1681-
16821675
/// Implementation of a typical `DrainFilter::next` method, given the predicate.
16831676
pub(super) fn next<F>(&mut self, pred: &mut F) -> Option<(K, V)>
16841677
where
16851678
F: FnMut(&K, &mut V) -> bool,
16861679
{
1687-
while let Some(mut kv) = unsafe { self.next_kv() } {
1680+
while let Ok(mut kv) = self.cur_leaf_edge.take()?.next_kv() {
16881681
let (k, v) = kv.kv_mut();
16891682
if pred(k, v) {
16901683
*self.length -= 1;

src/liballoc/tests/btree/map.rs

+46-12
Original file line numberDiff line numberDiff line change
@@ -887,18 +887,16 @@ mod test_drain_filter {
887887
}
888888
}
889889

890-
let mut map = BTreeMap::new();
891-
map.insert(0, D);
892-
map.insert(4, D);
893-
map.insert(8, D);
890+
// Keys are multiples of 4, so that each key is counted by a hexadecimal digit.
891+
let mut map = (0..3).map(|i| (i * 4, D)).collect::<BTreeMap<_, _>>();
894892

895893
catch_unwind(move || {
896894
drop(map.drain_filter(|i, _| {
897895
PREDS.fetch_add(1usize << i, Ordering::SeqCst);
898896
true
899897
}))
900898
})
901-
.ok();
899+
.unwrap_err();
902900

903901
assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
904902
assert_eq!(DROPS.load(Ordering::SeqCst), 3);
@@ -916,10 +914,8 @@ mod test_drain_filter {
916914
}
917915
}
918916

919-
let mut map = BTreeMap::new();
920-
map.insert(0, D);
921-
map.insert(4, D);
922-
map.insert(8, D);
917+
// Keys are multiples of 4, so that each key is counted by a hexadecimal digit.
918+
let mut map = (0..3).map(|i| (i * 4, D)).collect::<BTreeMap<_, _>>();
923919

924920
catch_unwind(AssertUnwindSafe(|| {
925921
drop(map.drain_filter(|i, _| {
@@ -930,7 +926,45 @@ mod test_drain_filter {
930926
}
931927
}))
932928
}))
933-
.ok();
929+
.unwrap_err();
930+
931+
assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
932+
assert_eq!(DROPS.load(Ordering::SeqCst), 1);
933+
assert_eq!(map.len(), 2);
934+
assert_eq!(map.first_entry().unwrap().key(), &4);
935+
assert_eq!(map.last_entry().unwrap().key(), &8);
936+
}
937+
938+
// Same as above, but attempt to use the iterator again after the panic in the predicate
939+
#[test]
940+
fn pred_panic_reuse() {
941+
static PREDS: AtomicUsize = AtomicUsize::new(0);
942+
static DROPS: AtomicUsize = AtomicUsize::new(0);
943+
944+
struct D;
945+
impl Drop for D {
946+
fn drop(&mut self) {
947+
DROPS.fetch_add(1, Ordering::SeqCst);
948+
}
949+
}
950+
951+
// Keys are multiples of 4, so that each key is counted by a hexadecimal digit.
952+
let mut map = (0..3).map(|i| (i * 4, D)).collect::<BTreeMap<_, _>>();
953+
954+
{
955+
let mut it = map.drain_filter(|i, _| {
956+
PREDS.fetch_add(1usize << i, Ordering::SeqCst);
957+
match i {
958+
0 => true,
959+
_ => panic!(),
960+
}
961+
});
962+
catch_unwind(AssertUnwindSafe(|| while it.next().is_some() {})).unwrap_err();
963+
// Iterator behaviour after a panic is explicitly unspecified,
964+
// so this is just the current implementation:
965+
let result = catch_unwind(AssertUnwindSafe(|| it.next()));
966+
assert!(matches!(result, Ok(None)));
967+
}
934968

935969
assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
936970
assert_eq!(DROPS.load(Ordering::SeqCst), 1);
@@ -1399,7 +1433,7 @@ fn test_into_iter_drop_leak_height_0() {
13991433
map.insert("d", D);
14001434
map.insert("e", D);
14011435

1402-
catch_unwind(move || drop(map.into_iter())).ok();
1436+
catch_unwind(move || drop(map.into_iter())).unwrap_err();
14031437

14041438
assert_eq!(DROPS.load(Ordering::SeqCst), 5);
14051439
}
@@ -1423,7 +1457,7 @@ fn test_into_iter_drop_leak_height_1() {
14231457
DROPS.store(0, Ordering::SeqCst);
14241458
PANIC_POINT.store(panic_point, Ordering::SeqCst);
14251459
let map: BTreeMap<_, _> = (0..size).map(|i| (i, D)).collect();
1426-
catch_unwind(move || drop(map.into_iter())).ok();
1460+
catch_unwind(move || drop(map.into_iter())).unwrap_err();
14271461
assert_eq!(DROPS.load(Ordering::SeqCst), size);
14281462
}
14291463
}

0 commit comments

Comments
 (0)