Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audit most uses of unsafe #566

Merged
merged 1 commit into from
Mar 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bitset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ macro_rules! define_bit_join {
type Type = Index;
type Value = ();
type Mask = $bitset;

// SAFETY: This just moves a `BitSet`; invariants of `Join` are fulfilled, since `Self::Value` cannot be mutated.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(self, ())
}

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn get(_: &mut Self::Value, id: Index) -> Self::Type {
id
}
Expand Down
13 changes: 13 additions & 0 deletions src/changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ impl<T> ChangeSet<T> {
T: AddAssign,
{
if self.mask.contains(entity.id()) {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
*self.inner.get_mut(entity.id()) += value;
}
} else {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
self.inner.insert(entity.id(), value);
}
Expand All @@ -73,6 +75,7 @@ impl<T> ChangeSet<T> {
/// Clear the changeset
pub fn clear(&mut self) {
for id in &self.mask {
// SAFETY: we checked the mask, thus it's safe to call
unsafe {
self.inner.remove(id);
}
Expand Down Expand Up @@ -110,10 +113,13 @@ impl<'a, T> Join for &'a mut ChangeSet<T> {
type Value = &'a mut DenseVecStorage<T>;
type Mask = &'a BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(&self.mask, &mut self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(v: &mut Self::Value, id: Index) -> Self::Type {
let value: *mut Self::Value = v as *mut Self::Value;
(*value).get_mut(id)
Expand All @@ -125,24 +131,31 @@ impl<'a, T> Join for &'a ChangeSet<T> {
type Value = &'a DenseVecStorage<T>;
type Mask = &'a BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(&self.mask, &self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
value.get(id)
}
}

/// A `Join` implementation for `ChangeSet` that simply removes all the entries on a call to `get`.
impl<T> Join for ChangeSet<T> {
type Type = T;
type Value = DenseVecStorage<T>;
type Mask = BitSet;

// SAFETY: No unsafe code and no invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(self.mask, self.inner)
}

// SAFETY: No unsafe code and no invariants to meet.
// `DistinctStorage` invariants are also met, but no `ParJoin` implementation exists yet.
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
value.remove(id)
}
Expand Down
40 changes: 40 additions & 0 deletions src/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,20 @@ pub trait Join {

/// Open this join by returning the mask and the storages.
///
/// # Safety
///
/// This is unsafe because implementations of this trait can permit
/// the `Value` to be mutated independently of the `Mask`.
/// If the `Mask` does not correctly report the status of the `Value`
/// then illegal memory access can occur.
unsafe fn open(self) -> (Self::Mask, Self::Value);

/// Get a joined component value by a given index.
///
/// # Safety
///
/// * A call to `get` must be preceded by a check if `id` is part of `Self::Mask`
/// * The implementation of this method may use unsafe code, but has no invariants to meet
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type;

/// If this `Join` typically returns all indices in the mask, then iterating over only it
Expand Down Expand Up @@ -261,10 +268,15 @@ where
type Type = Option<<T as Join>::Type>;
type Value = (<T as Join>::Mask, <T as Join>::Value);
type Mask = BitSetAll;

// SAFETY: This wraps another implementation of `open`, making it dependent on `J`'s correctness.
// We can safely assume `J` is valid, thus this must be valid, too. No invariants to meet.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let (mask, value) = self.0.open();
(BitSetAll, (mask, value))
}

// SAFETY: No invariants to meet and the unsafe code checks the mask, thus fulfills the requirements for calling `get`
unsafe fn get((mask, value): &mut Self::Value, id: Index) -> Self::Type {
if mask.contains(id) {
Some(<T as Join>::get(value, id))
Expand Down Expand Up @@ -293,6 +305,7 @@ impl<J: Join> JoinIter<J> {
println!("WARNING: `Join` possibly iterating through all indices, you might've made a join with all `MaybeJoin`s, which is unbounded in length.");
}

// SAFETY: We do not swap out the mask or the values, nor do we allow it by exposing them.
let (keys, values) = unsafe { j.open() };
JoinIter {
keys: keys.iter(),
Expand Down Expand Up @@ -353,6 +366,7 @@ impl<J: Join> JoinIter<J> {
/// ```
pub fn get(&mut self, entity: Entity, entities: &Entities) -> Option<J::Type> {
if self.keys.contains(entity.id()) && entities.is_alive(entity) {
// SAFETY: the mask (`keys`) is checked as specified in the docs of `get`.
Some(unsafe { J::get(&mut self.values, entity.id()) })
} else {
None
Expand All @@ -367,6 +381,7 @@ impl<J: Join> JoinIter<J> {
/// so the caller should ensure it instead.
pub fn get_unchecked(&mut self, index: Index) -> Option<J::Type> {
if self.keys.contains(index) {
// SAFETY: the mask (`keys`) is checked as specified in the docs of `get`.
Some(unsafe { J::get(&mut self.values, index) })
} else {
None
Expand All @@ -378,6 +393,8 @@ impl<J: Join> std::iter::Iterator for JoinIter<J> {
type Item = J::Type;

fn next(&mut self) -> Option<J::Type> {
// SAFETY: since `idx` is yielded from `keys` (the mask), it is necessarily a part of it.
// Thus, requirements are fulfilled for calling `get`.
self.keys
.next()
.map(|idx| unsafe { J::get(&mut self.values, idx) })
Expand All @@ -395,6 +412,9 @@ macro_rules! define_open {
type Value = ($($from::Value),*,);
type Mask = <($($from::Mask,)*) as BitAnd>::Value;
#[allow(non_snake_case)]

// SAFETY: While we do expose the mask and the values and therefore would allow swapping them,
// this method is `unsafe` and relies on the same invariants.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let ($($from,)*) = self;
let ($($from,)*) = ($($from.open(),)*);
Expand All @@ -404,6 +424,8 @@ macro_rules! define_open {
)
}

// SAFETY: No invariants to meet and `get` is safe to call as the caller must have checked the mask,
// which only has a key that exists in all of the storages.
#[allow(non_snake_case)]
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
let &mut ($(ref mut $from,)*) = v;
Expand All @@ -417,6 +439,10 @@ macro_rules! define_open {
unconstrained
}
}

// SAFETY: This is safe to implement since all components implement `ParJoin`.
// If the access of every individual `get` leads to disjoint memory access, calling
// all of them after another does in no case lead to access of common memory.
#[cfg(feature = "parallel")]
unsafe impl<$($from,)*> ParJoin for ($($from),*,)
where $($from: ParJoin),*,
Expand Down Expand Up @@ -463,10 +489,15 @@ macro_rules! immutable_resource_join {
type Type = <&'a T as Join>::Type;
type Value = <&'a T as Join>::Value;
type Mask = <&'a T as Join>::Mask;

// SAFETY: This only wraps `T` and, while exposing the mask and the values,
// requires the same invariants as the original implementation and is thus safe.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
self.deref().open()
}

// SAFETY: The mask of `Self` and `T` are identical, thus a check to `Self`'s mask (which is required)
// is equal to a check of `T`'s mask, which makes `get` safe to call.
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
<&'a T as Join>::get(v, i)
}
Expand All @@ -477,6 +508,8 @@ macro_rules! immutable_resource_join {
}
}

// SAFETY: This is just a wrapper of `T`'s implementation for `ParJoin` and can
// in no case lead to other memory access patterns.
#[cfg(feature = "parallel")]
unsafe impl<'a, 'b, T> ParJoin for &'a $ty
where
Expand All @@ -498,10 +531,15 @@ macro_rules! mutable_resource_join {
type Type = <&'a mut T as Join>::Type;
type Value = <&'a mut T as Join>::Value;
type Mask = <&'a mut T as Join>::Mask;

// SAFETY: This only wraps `T` and, while exposing the mask and the values,
// requires the same invariants as the original implementation and is thus safe.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
self.deref_mut().open()
}

// SAFETY: The mask of `Self` and `T` are identical, thus a check to `Self`'s mask (which is required)
// is equal to a check of `T`'s mask, which makes `get_mut` safe to call.
unsafe fn get(v: &mut Self::Value, i: Index) -> Self::Type {
<&'a mut T as Join>::get(v, i)
}
Expand All @@ -512,6 +550,8 @@ macro_rules! mutable_resource_join {
}
}

// SAFETY: This is just a wrapper of `T`'s implementation for `ParJoin` and can
// in no case lead to other memory access patterns.
#[cfg(feature = "parallel")]
unsafe impl<'a, 'b, T> ParJoin for &'a mut $ty
where
Expand Down
18 changes: 18 additions & 0 deletions src/join/par_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ use join::Join;
/// The purpose of the `ParJoin` trait is to provide a way
/// to access multiple storages in parallel at the same time with
/// the merged bit set.
///
/// # Safety
///
/// The implementation of `ParallelIterator` for `ParJoin` makes multiple assumptions on the structure of `Self`.
/// In particular, `<Self as Join>::get` must be callable from multiple threads, simultaneously, without mutating
/// values not exclusively associated with `id`.
// NOTE: This is currently unspecified behavior. It seems very unlikely that it breaks in the future,
// but technically it's not specified as valid Rust code.
pub unsafe trait ParJoin: Join {
/// Create a joined parallel iterator over the contents.
fn par_join(self) -> JoinParIter<Self>
Expand Down Expand Up @@ -45,6 +53,8 @@ where
let (keys, values) = unsafe { self.0.open() };
// Create a bit producer which splits on up to three levels
let producer = BitProducer((&keys).iter(), 3);
// HACK: use `UnsafeCell` to share `values` between threads;
// this is the unspecified behavior referred to above.
let values = UnsafeCell::new(values);

bridge_unindexed(JoinProducer::<J>::new(producer, &values), consumer)
Expand Down Expand Up @@ -74,6 +84,14 @@ where
}
}

// SAFETY: `Send` is safe to implement if all components of `Self` are logically `Send`.
// `keys` already has `Send` implemented, thus no reasoning is required.
// `values` is a reference to an `UnsafeCell` wrapping `J::Value`;
// `J::Value` is constrained to implement `Send`.
// `UnsafeCell` provides interior mutability, but the specification of it allows sharing
// as long as access does not happen simultaneously; this makes it generally safe to `Send`,
// but we are accessing it simultaneously, which is technically not allowed.
// Also see https://github.com/slide-rs/specs/issues/220
unsafe impl<'a, J> Send for JoinProducer<'a, J>
where
J: Join + Send,
Expand Down
2 changes: 2 additions & 0 deletions src/storage/drain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ where
type Value = &'a mut MaskedStorage<T>;
type Mask = BitSet;

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
let mask = self.data.mask.clone();

(mask, self.data)
}

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn get(value: &mut Self::Value, id: Index) -> T {
value.remove(id).expect("Tried to access same index twice")
}
Expand Down
12 changes: 12 additions & 0 deletions src/storage/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ where
if self.entities.is_alive(e) {
unsafe {
let entries = self.entries();
// SAFETY: This is safe since we're not swapping out the mask or the values.
let (_, mut value): (BitSetAll, _) = entries.open();
// SAFETY: We did check the mask, because the mask is `BitSetAll` and every index is part of it.
Ok(Entries::get(&mut value, e.id()))
}
} else {
Expand Down Expand Up @@ -132,10 +134,13 @@ where
type Value = &'a mut Storage<'b, T, D>;
type Mask = BitSetAll;

// SAFETY: No invariants to meet and no unsafe code.
unsafe fn open(self) -> (Self::Mask, Self::Value) {
(BitSetAll, self.0)
}

// SAFETY: We are lengthening the lifetime of `value` to `'a`;
// TODO: how to prove this is safe?
unsafe fn get(value: &mut Self::Value, id: Index) -> Self::Type {
// This is HACK. See implementation of Join for &'a mut Storage<'e, T, D> for
// details why it is necessary.
Expand Down Expand Up @@ -172,6 +177,8 @@ where
{
/// Get a reference to the component associated with the entity.
pub fn get(&self) -> &T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get(self.id) }
}
}
Expand All @@ -183,12 +190,16 @@ where
{
/// Get a mutable reference to the component associated with the entity.
pub fn get_mut(&mut self) -> &mut T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
}

/// Converts the `OccupiedEntry` into a mutable reference bounded by
/// the storage's lifetime.
pub fn into_mut(self) -> &'a mut T {
// SAFETY: This is safe since `OccupiedEntry` is only constructed
// after checking the mask.
unsafe { self.storage.data.inner.get_mut(self.id) }
}

Expand Down Expand Up @@ -218,6 +229,7 @@ where
/// Inserts a value into the storage.
pub fn insert(self, component: T) -> &'a mut T {
self.storage.data.mask.add(self.id);
// SAFETY: This is safe since we added `self.id` to the mask.
unsafe {
self.storage.data.inner.insert(self.id, component);
self.storage.data.inner.get_mut(self.id)
Expand Down
Loading