Skip to content

Commit 45c6230

Browse files
committed
Add HashSet::drain_filter method
Fixes #178.
1 parent 86641c8 commit 45c6230

File tree

3 files changed

+132
-19
lines changed

3 files changed

+132
-19
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
1313
- Added methods that allow re-using a `RawIter` for `RawDrain`,
1414
`RawIntoIter`, and `RawParIter`. (#175)
1515
- Added `reflect_remove` and `reflect_insert` to `RawIter`. (#175)
16+
- Added a `drain_filter` function to `HashSet`. (#179)
1617

1718
### Changed
1819
- Deprecated `RawTable::erase_no_drop` in favor of `erase` and `remove`. (#176)

src/map.rs

+36-18
Original file line numberDiff line numberDiff line change
@@ -603,14 +603,17 @@ impl<K, V, S> HashMap<K, V, S> {
603603
/// assert_eq!(drained.count(), 4);
604604
/// assert_eq!(map.len(), 4);
605605
/// ```
606+
#[cfg_attr(feature = "inline-more", inline)]
606607
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, K, V, F>
607608
where
608609
F: FnMut(&K, &mut V) -> bool,
609610
{
610611
DrainFilter {
611612
f,
612-
iter: unsafe { self.table.iter() },
613-
table: &mut self.table,
613+
inner: DrainFilterInner {
614+
iter: unsafe { self.table.iter() },
615+
table: &mut self.table,
616+
},
614617
}
615618
}
616619

@@ -1331,45 +1334,60 @@ where
13311334
F: FnMut(&K, &mut V) -> bool,
13321335
{
13331336
f: F,
1334-
iter: RawIter<(K, V)>,
1335-
table: &'a mut RawTable<(K, V)>,
1337+
inner: DrainFilterInner<'a, K, V>,
13361338
}
13371339

13381340
impl<'a, K, V, F> Drop for DrainFilter<'a, K, V, F>
13391341
where
13401342
F: FnMut(&K, &mut V) -> bool,
13411343
{
1344+
#[cfg_attr(feature = "inline-more", inline)]
13421345
fn drop(&mut self) {
1343-
struct DropGuard<'r, 'a, K, V, F>(&'r mut DrainFilter<'a, K, V, F>)
1344-
where
1345-
F: FnMut(&K, &mut V) -> bool;
1346-
1347-
impl<'r, 'a, K, V, F> Drop for DropGuard<'r, 'a, K, V, F>
1348-
where
1349-
F: FnMut(&K, &mut V) -> bool,
1350-
{
1351-
fn drop(&mut self) {
1352-
while let Some(_) = self.0.next() {}
1353-
}
1354-
}
13551346
while let Some(item) = self.next() {
1356-
let guard = DropGuard(self);
1347+
let guard = ConsumeAllOnDrop(self);
13571348
drop(item);
13581349
mem::forget(guard);
13591350
}
13601351
}
13611352
}
13621353

1354+
pub(super) struct ConsumeAllOnDrop<'a, T: Iterator>(pub &'a mut T);
1355+
1356+
impl<T: Iterator> Drop for ConsumeAllOnDrop<'_, T> {
1357+
#[cfg_attr(feature = "inline-more", inline)]
1358+
fn drop(&mut self) {
1359+
self.0.for_each(drop)
1360+
}
1361+
}
1362+
13631363
impl<K, V, F> Iterator for DrainFilter<'_, K, V, F>
13641364
where
13651365
F: FnMut(&K, &mut V) -> bool,
13661366
{
13671367
type Item = (K, V);
1368+
1369+
#[cfg_attr(feature = "inline-more", inline)]
13681370
fn next(&mut self) -> Option<Self::Item> {
1371+
self.inner.next(&mut self.f)
1372+
}
1373+
}
1374+
1375+
/// Portions of `DrainFilter` shared with `set::DrainFilter`
1376+
pub(super) struct DrainFilterInner<'a, K, V> {
1377+
pub iter: RawIter<(K, V)>,
1378+
pub table: &'a mut RawTable<(K, V)>,
1379+
}
1380+
1381+
impl<K, V> DrainFilterInner<'_, K, V> {
1382+
#[cfg_attr(feature = "inline-more", inline)]
1383+
pub(super) fn next<F>(&mut self, f: &mut F) -> Option<(K, V)>
1384+
where
1385+
F: FnMut(&K, &mut V) -> bool,
1386+
{
13691387
unsafe {
13701388
while let Some(item) = self.iter.next() {
13711389
let &mut (ref key, ref mut value) = item.as_mut();
1372-
if !(self.f)(key, value) {
1390+
if !f(key, value) {
13731391
return Some(self.table.remove(item));
13741392
}
13751393
}

src/set.rs

+95-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ use core::borrow::Borrow;
44
use core::fmt;
55
use core::hash::{BuildHasher, Hash};
66
use core::iter::{Chain, FromIterator, FusedIterator};
7+
use core::mem;
78
use core::ops::{BitAnd, BitOr, BitXor, Sub};
89

9-
use super::map::{self, DefaultHashBuilder, HashMap, Keys};
10+
use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys};
1011

1112
// Future Optimization (FIXME!)
1213
// =============================
@@ -285,6 +286,39 @@ impl<T, S> HashSet<T, S> {
285286
self.map.retain(|k, _| f(k));
286287
}
287288

289+
/// Drains elements which are false under the given predicate,
290+
/// and returns an iterator over the removed items.
291+
///
292+
/// In other words, move all elements `e` such that `f(&e)` returns `false` out
293+
/// into another iterator.
294+
///
295+
/// When the returned DrainedFilter is dropped, the elements that don't satisfy
296+
/// the predicate are dropped from the set.
297+
///
298+
/// # Examples
299+
///
300+
/// ```
301+
/// use hashbrown::HashSet;
302+
///
303+
/// let mut set: HashSet<i32> = (0..8).collect();
304+
/// let drained = set.drain_filter(|&k| k % 2 == 0);
305+
/// assert_eq!(drained.count(), 4);
306+
/// assert_eq!(set.len(), 4);
307+
/// ```
308+
#[cfg_attr(feature = "inline-more", inline)]
309+
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, T, F>
310+
where
311+
F: FnMut(&T) -> bool,
312+
{
313+
DrainFilter {
314+
f,
315+
inner: DrainFilterInner {
316+
iter: unsafe { self.map.table.iter() },
317+
table: &mut self.map.table,
318+
},
319+
}
320+
}
321+
288322
/// Clears the set, removing all values.
289323
///
290324
/// # Examples
@@ -1185,6 +1219,21 @@ pub struct Drain<'a, K> {
11851219
iter: map::Drain<'a, K, ()>,
11861220
}
11871221

1222+
/// A draining iterator over entries of a `HashSet` which don't satisfy the predicate `f`.
1223+
///
1224+
/// This `struct` is created by the [`drain_filter`] method on [`HashSet`]. See its
1225+
/// documentation for more.
1226+
///
1227+
/// [`drain_filter`]: struct.HashSet.html#method.drain_filter
1228+
/// [`HashSet`]: struct.HashSet.html
1229+
pub struct DrainFilter<'a, K, F>
1230+
where
1231+
F: FnMut(&K) -> bool,
1232+
{
1233+
f: F,
1234+
inner: DrainFilterInner<'a, K, ()>,
1235+
}
1236+
11881237
/// A lazy iterator producing elements in the intersection of `HashSet`s.
11891238
///
11901239
/// This `struct` is created by the [`intersection`] method on [`HashSet`].
@@ -1365,6 +1414,34 @@ impl<K: fmt::Debug> fmt::Debug for Drain<'_, K> {
13651414
}
13661415
}
13671416

1417+
impl<'a, K, F> Drop for DrainFilter<'a, K, F>
1418+
where
1419+
F: FnMut(&K) -> bool,
1420+
{
1421+
#[cfg_attr(feature = "inline-more", inline)]
1422+
fn drop(&mut self) {
1423+
while let Some(item) = self.next() {
1424+
let guard = ConsumeAllOnDrop(self);
1425+
drop(item);
1426+
mem::forget(guard);
1427+
}
1428+
}
1429+
}
1430+
1431+
impl<K, F> Iterator for DrainFilter<'_, K, F>
1432+
where
1433+
F: FnMut(&K) -> bool,
1434+
{
1435+
type Item = K;
1436+
1437+
#[cfg_attr(feature = "inline-more", inline)]
1438+
fn next(&mut self) -> Option<Self::Item> {
1439+
let f = &mut self.f;
1440+
let (k, _) = self.inner.next(&mut |k, _| f(k))?;
1441+
Some(k)
1442+
}
1443+
}
1444+
13681445
impl<T, S> Clone for Intersection<'_, T, S> {
13691446
#[cfg_attr(feature = "inline-more", inline)]
13701447
fn clone(&self) -> Self {
@@ -1973,4 +2050,21 @@ mod test_set {
19732050
assert!(set.contains(&4));
19742051
assert!(set.contains(&6));
19752052
}
2053+
2054+
#[test]
2055+
fn test_drain_filter() {
2056+
{
2057+
let mut set: HashSet<i32> = (0..8).collect();
2058+
let drained = set.drain_filter(|&k| k % 2 == 0);
2059+
let mut out = drained.collect::<Vec<_>>();
2060+
out.sort_unstable();
2061+
assert_eq!(vec![1, 3, 5, 7], out);
2062+
assert_eq!(set.len(), 4);
2063+
}
2064+
{
2065+
let mut set: HashSet<i32> = (0..8).collect();
2066+
drop(set.drain_filter(|&k| k % 2 == 0));
2067+
assert_eq!(set.len(), 4, "Removes non-matching items on drop");
2068+
}
2069+
}
19762070
}

0 commit comments

Comments
 (0)