Skip to content

Commit

Permalink
Add the ability to recover the original HashTable from an entry
Browse files Browse the repository at this point in the history
  • Loading branch information
Amanieu committed Sep 22, 2023
1 parent 142e536 commit 98e2f78
Showing 1 changed file with 59 additions and 43 deletions.
102 changes: 59 additions & 43 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct HashTable<T, A = Global>
where
A: Allocator,
{
pub(crate) table: RawTable<T, A>,
pub(crate) raw: RawTable<T, A>,
}

impl<T> HashTable<T, Global> {
Expand All @@ -65,7 +65,7 @@ impl<T> HashTable<T, Global> {
/// ```
pub const fn new() -> Self {
Self {
table: RawTable::new(),
raw: RawTable::new(),
}
}

Expand All @@ -84,7 +84,7 @@ impl<T> HashTable<T, Global> {
/// ```
pub fn with_capacity(capacity: usize) -> Self {
Self {
table: RawTable::with_capacity(capacity),
raw: RawTable::with_capacity(capacity),
}
}
}
Expand Down Expand Up @@ -133,7 +133,7 @@ where
/// ```
pub const fn new_in(alloc: A) -> Self {
Self {
table: RawTable::new_in(alloc),
raw: RawTable::new_in(alloc),
}
}

Expand Down Expand Up @@ -182,13 +182,13 @@ where
/// ```
pub fn with_capacity_in(capacity: usize, alloc: A) -> Self {
Self {
table: RawTable::with_capacity_in(capacity, alloc),
raw: RawTable::with_capacity_in(capacity, alloc),
}
}

/// Returns a reference to the underlying allocator.
pub fn allocator(&self) -> &A {
self.table.allocator()
self.raw.allocator()
}

/// Returns a reference to an entry in the table with the given hash and
Expand Down Expand Up @@ -222,7 +222,7 @@ where
/// # }
/// ```
pub fn find(&self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&T> {
self.table
self.raw
.find(hash, eq)
.map(|bucket| unsafe { bucket.as_ref() })
}
Expand Down Expand Up @@ -263,7 +263,7 @@ where
/// # }
/// ```
pub fn find_mut(&mut self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&mut T> {
self.table
self.raw
.find(hash, eq)
.map(|bucket| unsafe { bucket.as_mut() })
}
Expand Down Expand Up @@ -292,7 +292,7 @@ where
/// let hasher = BuildHasherDefault::<AHasher>::default();
/// let hasher = |val: &_| hasher.hash_one(val);
/// table.insert_unchecked(hasher(&1), (1, "a"), |val| hasher(&val.0));
/// if let Some(entry) = table.find_entry(hasher(&1), |val| val.0 == 1) {
/// if let Ok(entry) = table.find_entry(hasher(&1), |val| val.0 == 1) {
/// entry.remove();
/// }
/// assert_eq!(table.find(hasher(&1), |val| val.0 == 1), None);
Expand All @@ -306,12 +306,15 @@ where
&mut self,
hash: u64,
eq: impl FnMut(&T) -> bool,
) -> Option<OccupiedEntry<'_, T, A>> {
self.table.find(hash, eq).map(|bucket| OccupiedEntry {
hash,
bucket,
table: &mut self.table,
})
) -> Result<OccupiedEntry<'_, T, A>, &mut Self> {
match self.raw.find(hash, eq) {
Some(bucket) => Ok(OccupiedEntry {
hash,
bucket,
table: self,
}),
None => Err(self),
}
}

/// Returns an `Entry` for an entry in the table with the given hash
Expand Down Expand Up @@ -365,16 +368,16 @@ where
eq: impl FnMut(&T) -> bool,
hasher: impl Fn(&T) -> u64,
) -> Entry<'_, T, A> {
match self.table.find_or_find_insert_slot(hash, eq, hasher) {
match self.raw.find_or_find_insert_slot(hash, eq, hasher) {
Ok(bucket) => Entry::Occupied(OccupiedEntry {
hash,
bucket,
table: &mut self.table,
table: self,
}),
Err(insert_slot) => Entry::Vacant(VacantEntry {
hash,
insert_slot,
table: &mut self.table,
table: self,
}),
}
}
Expand All @@ -393,11 +396,11 @@ where
value: T,
hasher: impl Fn(&T) -> u64,
) -> OccupiedEntry<'_, T, A> {
let bucket = self.table.insert(hash, value, hasher);
let bucket = self.raw.insert(hash, value, hasher);
OccupiedEntry {
hash,
bucket,
table: &mut self.table,
table: self,
}
}

Expand Down Expand Up @@ -425,7 +428,7 @@ where
/// # }
/// ```
pub fn clear(&mut self) {
self.table.clear();
self.raw.clear();
}

/// Shrinks the capacity of the table as much as possible. It will drop
Expand Down Expand Up @@ -459,7 +462,7 @@ where
/// # }
/// ```
pub fn shrink_to_fit(&mut self, hasher: impl Fn(&T) -> u64) {
self.table.shrink_to(self.len(), hasher)
self.raw.shrink_to(self.len(), hasher)
}

/// Shrinks the capacity of the table with a lower limit. It will drop
Expand Down Expand Up @@ -498,7 +501,7 @@ where
/// # }
/// ```
pub fn shrink_to(&mut self, min_capacity: usize, hasher: impl Fn(&T) -> u64) {
self.table.shrink_to(min_capacity, hasher);
self.raw.shrink_to(min_capacity, hasher);
}

/// Reserves capacity for at least `additional` more elements to be inserted
Expand Down Expand Up @@ -538,7 +541,7 @@ where
/// # }
/// ```
pub fn reserve(&mut self, additional: usize, hasher: impl Fn(&T) -> u64) {
self.table.reserve(additional, hasher)
self.raw.reserve(additional, hasher)
}

/// Tries to reserve capacity for at least `additional` more elements to be inserted
Expand Down Expand Up @@ -579,7 +582,7 @@ where
additional: usize,
hasher: impl Fn(&T) -> u64,
) -> Result<(), TryReserveError> {
self.table.try_reserve(additional, hasher)
self.raw.try_reserve(additional, hasher)
}

/// Returns the number of elements the table can hold without reallocating.
Expand All @@ -592,7 +595,7 @@ where
/// assert!(table.capacity() >= 100);
/// ```
pub fn capacity(&self) -> usize {
self.table.capacity()
self.raw.capacity()
}

/// Returns the number of elements in the table.
Expand All @@ -619,7 +622,7 @@ where
/// # }
/// ```
pub fn len(&self) -> usize {
self.table.len()
self.raw.len()
}

/// Returns `true` if the set contains no elements.
Expand All @@ -646,7 +649,7 @@ where
/// # }
/// ```
pub fn is_empty(&self) -> bool {
self.table.is_empty()
self.raw.is_empty()
}

/// An iterator visiting all elements in arbitrary order.
Expand Down Expand Up @@ -679,7 +682,7 @@ where
/// ```
pub fn iter(&self) -> Iter<'_, T> {
Iter {
inner: unsafe { self.table.iter() },
inner: unsafe { self.raw.iter() },
marker: PhantomData,
}
}
Expand Down Expand Up @@ -731,7 +734,7 @@ where
/// ```
pub fn iter_mut(&mut self) -> IterMut<'_, T> {
IterMut {
inner: unsafe { self.table.iter() },
inner: unsafe { self.raw.iter() },
marker: PhantomData,
}
}
Expand Down Expand Up @@ -766,9 +769,9 @@ where
pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) {
// Here we only use `iter` as a temporary, preventing use-after-free
unsafe {
for item in self.table.iter() {
for item in self.raw.iter() {
if !f(item.as_mut()) {
self.table.erase(item);
self.raw.erase(item);
}
}
}
Expand Down Expand Up @@ -807,7 +810,7 @@ where
/// ```
pub fn drain(&mut self) -> Drain<'_, T, A> {
Drain {
inner: self.table.drain(),
inner: self.raw.drain(),
}
}

Expand Down Expand Up @@ -858,8 +861,8 @@ where
ExtractIf {
f,
inner: RawExtractIf {
iter: unsafe { self.table.iter() },
table: &mut self.table,
iter: unsafe { self.raw.iter() },
table: &mut self.raw,
},
}
}
Expand Down Expand Up @@ -922,7 +925,7 @@ where
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
self.table.get_many_mut(hashes, eq)
self.raw.get_many_mut(hashes, eq)
}

/// Attempts to get mutable references to `N` values in the map at once, without validating that
Expand Down Expand Up @@ -992,7 +995,7 @@ where
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
self.table.get_many_unchecked_mut(hashes, eq)
self.raw.get_many_unchecked_mut(hashes, eq)
}
}

Expand All @@ -1005,7 +1008,7 @@ where

fn into_iter(self) -> IntoIter<T, A> {
IntoIter {
inner: self.table.into_iter(),
inner: self.raw.into_iter(),
}
}
}
Expand Down Expand Up @@ -1040,7 +1043,7 @@ where
{
fn default() -> Self {
Self {
table: Default::default(),
raw: Default::default(),
}
}
}
Expand All @@ -1052,7 +1055,7 @@ where
{
fn clone(&self) -> Self {
Self {
table: self.table.clone(),
raw: self.raw.clone(),
}
}
}
Expand Down Expand Up @@ -1429,7 +1432,7 @@ where
{
hash: u64,
bucket: Bucket<T>,
table: &'a mut RawTable<T, A>,
table: &'a mut HashTable<T, A>,
}

unsafe impl<T, A> Send for OccupiedEntry<'_, T, A>
Expand Down Expand Up @@ -1496,7 +1499,7 @@ where
/// # }
/// ```
pub fn remove(self) -> (T, VacantEntry<'a, T, A>) {
let (val, slot) = unsafe { self.table.remove(self.bucket) };
let (val, slot) = unsafe { self.table.raw.remove(self.bucket) };
(
val,
VacantEntry {
Expand Down Expand Up @@ -1642,6 +1645,12 @@ where
pub fn into_mut(self) -> &'a mut T {
unsafe { self.bucket.as_mut() }
}

/// Converts the OccupiedEntry into a mutable reference to the underlying
/// table.
pub fn into_table(self) -> &'a mut HashTable<T, A> {
self.table
}
}

/// A view into a vacant entry in a `HashTable`.
Expand Down Expand Up @@ -1689,7 +1698,7 @@ where
{
hash: u64,
insert_slot: InsertSlot,
table: &'a mut RawTable<T, A>,
table: &'a mut HashTable<T, A>,
}

impl<T: fmt::Debug, A: Allocator> fmt::Debug for VacantEntry<'_, T, A> {
Expand Down Expand Up @@ -1737,6 +1746,7 @@ where
pub fn insert(self, value: T) -> OccupiedEntry<'a, T, A> {
let bucket = unsafe {
self.table
.raw
.insert_in_slot(self.hash, self.insert_slot, value)
};
OccupiedEntry {
Expand All @@ -1745,6 +1755,12 @@ where
table: self.table,
}
}

/// Converts the OccupiedEntry into a mutable reference to the underlying
/// table.
pub fn into_table(self) -> &'a mut HashTable<T, A> {
self.table
}
}

/// An iterator over the entries of a `HashTable` in arbitrary order.
Expand Down

0 comments on commit 98e2f78

Please sign in to comment.