Skip to content

Commit

Permalink
Harden FamStructWrapper against integer overflows
Browse files Browse the repository at this point in the history
Add some additional checks around integer overflows when
multiplying/casting, to ensure that code added in the future does not
trigger potential overflow bugs.

Signed-off-by: Patrick Roy <roypat@amazon.co.uk>
  • Loading branch information
roypat authored and andreeaflorescu committed Sep 22, 2023
1 parent 3fb8f76 commit 5bf1061
Showing 1 changed file with 50 additions and 17 deletions.
67 changes: 50 additions & 17 deletions src/fam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,14 @@ impl<T: Default + FamStruct> FamStructWrapper<T> {
///
/// Get the capacity required by mem_allocator in order to hold
/// the provided number of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry).
fn mem_allocator_len(fam_len: usize) -> usize {
let wrapper_size_in_bytes = size_of::<T>() + fam_len * size_of::<T::Entry>();
(wrapper_size_in_bytes + size_of::<T>() - 1) / size_of::<T>()
/// Returns `None` if the required length would overflow usize.
fn mem_allocator_len(fam_len: usize) -> Option<usize> {
let wrapper_size_in_bytes =
size_of::<T>().checked_add(fam_len.checked_mul(size_of::<T::Entry>())?)?;

wrapper_size_in_bytes
.checked_add(size_of::<T>().checked_sub(1)?)?
.checked_div(size_of::<T>())
}

/// Convert `mem_allocator` len to FAM len.
Expand Down Expand Up @@ -206,7 +211,8 @@ impl<T: Default + FamStruct> FamStructWrapper<T> {
return Err(Error::SizeLimitExceeded);
}
let required_mem_allocator_capacity =
FamStructWrapper::<T>::mem_allocator_len(num_elements);
FamStructWrapper::<T>::mem_allocator_len(num_elements)
.ok_or(Error::SizeLimitExceeded)?;

let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);
mem_allocator.push(T::default());
Expand Down Expand Up @@ -326,17 +332,20 @@ impl<T: Default + FamStruct> FamStructWrapper<T> {
///
/// If the capacity is already reserved, this method doesn't do anything.
/// If not this will trigger a reallocation of the underlying buffer.
fn reserve(&mut self, additional: usize) {
fn reserve(&mut self, additional: usize) -> Result<(), Error> {
let desired_capacity = self.len() + additional;
if desired_capacity <= self.capacity() {
return;
return Ok(());
}

let current_mem_allocator_len = self.mem_allocator.len();
let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(desired_capacity);
let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(desired_capacity)
.ok_or(Error::SizeLimitExceeded)?;
let additional_mem_allocator_len = required_mem_allocator_len - current_mem_allocator_len;

self.mem_allocator.reserve(additional_mem_allocator_len);

Ok(())
}

/// Update the length of the FamStructWrapper.
Expand All @@ -352,7 +361,10 @@ impl<T: Default + FamStruct> FamStructWrapper<T> {
///
/// When len is greater than the max possible len it returns Error::SizeLimitExceeded.
fn set_len(&mut self, len: usize) -> Result<(), Error> {
let additional_elements: isize = len as isize - self.len() as isize;
let additional_elements = isize::try_from(len)
.and_then(|len| isize::try_from(self.len()).map(|self_len| len - self_len))
.map_err(|_| Error::SizeLimitExceeded)?;

// If len == self.len there's nothing to do.
if additional_elements == 0 {
return Ok(());
Expand All @@ -365,11 +377,12 @@ impl<T: Default + FamStruct> FamStructWrapper<T> {
return Err(Error::SizeLimitExceeded);
}
// Reserve additional capacity.
self.reserve(additional_elements as usize);
self.reserve(additional_elements as usize)?;
}

let current_mem_allocator_len = self.mem_allocator.len();
let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(len);
let required_mem_allocator_len =
FamStructWrapper::<T>::mem_allocator_len(len).ok_or(Error::SizeLimitExceeded)?;
// Update the len of the `mem_allocator`.
// SAFETY: This is safe since enough capacity has been reserved.
unsafe {
Expand Down Expand Up @@ -445,9 +458,9 @@ impl<T: Default + FamStruct + PartialEq> PartialEq for FamStructWrapper<T> {
impl<T: Default + FamStruct> Clone for FamStructWrapper<T> {
fn clone(&self) -> Self {
// The number of entries (self.as_slice().len()) can't be > T::max_len() since `self` is a
// valid `FamStructWrapper`.
// valid `FamStructWrapper`. This makes the .unwrap() safe.
let required_mem_allocator_capacity =
FamStructWrapper::<T>::mem_allocator_len(self.as_slice().len());
FamStructWrapper::<T>::mem_allocator_len(self.as_slice().len()).unwrap();

let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);

Expand Down Expand Up @@ -581,6 +594,7 @@ macro_rules! generate_fam_struct_impl {
#[cfg(test)]
mod tests {
#![allow(clippy::undocumented_unsafe_blocks)]

#[cfg(feature = "with-serde")]
use serde_derive::{Deserialize, Serialize};

Expand Down Expand Up @@ -678,12 +692,30 @@ mod tests {
let fam_len = pair.0;
let mem_allocator_len = pair.1;
assert_eq!(
mem_allocator_len,
Some(mem_allocator_len),
MockFamStructWrapper::mem_allocator_len(fam_len)
);
}
}

#[repr(C)]
#[derive(Default, PartialEq)]
struct MockFamStructU8 {
pub len: u32,
pub padding: u32,
pub entries: __IncompleteArrayField<u8>,
}
generate_fam_struct_impl!(MockFamStructU8, u8, entries, u32, len, 100);
type MockFamStructWrapperU8 = FamStructWrapper<MockFamStructU8>;
#[test]
fn test_invalid_type_conversion() {
let mut adapter = MockFamStructWrapperU8::new(10).unwrap();
assert!(matches!(
adapter.set_len(0xffff_ffff_ffff_ff00),
Err(Error::SizeLimitExceeded)
));
}

#[test]
fn test_wrapper_len() {
for pair in MEM_ALLOCATOR_LEN_TO_FAM_LEN {
Expand Down Expand Up @@ -785,7 +817,7 @@ mod tests {
let num_elements = pair.0;
let required_mem_allocator_len = pair.1;

adapter.reserve(num_elements);
adapter.reserve(num_elements).unwrap();

assert!(adapter.mem_allocator.capacity() >= required_mem_allocator_len);
assert_eq!(0, adapter.len());
Expand All @@ -794,7 +826,7 @@ mod tests {

// test that when the capacity is already reserved, the method doesn't do anything
let current_capacity = adapter.capacity();
adapter.reserve(current_capacity - 1);
adapter.reserve(current_capacity - 1).unwrap();
assert_eq!(current_capacity, adapter.capacity());
}

Expand Down Expand Up @@ -831,7 +863,8 @@ mod tests {
assert_eq!(adapter.as_slice()[i], i as u32);
assert_eq!(adapter.len(), i + 1);
assert!(
adapter.mem_allocator.capacity() >= MockFamStructWrapper::mem_allocator_len(i + 1)
adapter.mem_allocator.capacity()
>= MockFamStructWrapper::mem_allocator_len(i + 1).unwrap()
);
}

Expand All @@ -858,7 +891,7 @@ mod tests {
assert_eq!(adapter.len(), num_retained_entries);
assert!(
adapter.mem_allocator.capacity()
>= MockFamStructWrapper::mem_allocator_len(num_retained_entries)
>= MockFamStructWrapper::mem_allocator_len(num_retained_entries).unwrap()
);
}

Expand Down

0 comments on commit 5bf1061

Please sign in to comment.