Merge pull request #57 from doubledup/checked-sub
Note possible checked_sub cases
ralexstokes authored Jun 10, 2023
2 parents e4a1118 + 8bde6b5 commit d72b6f9
Showing 5 changed files with 97 additions and 44 deletions.
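
For context on the commit theme: on unsigned integers, Rust's plain `-` panics on underflow in debug builds and wraps in release builds, whereas `checked_sub` returns an `Option`. A minimal sketch of the two styles this commit weighs (illustrative only, not code from the diff):

fn remaining_checked(total: usize, used: usize) -> Option<usize> {
    // Explicit form: underflow becomes a recoverable `None`.
    total.checked_sub(used)
}

fn remaining_unchecked(total: usize, used: usize) -> usize {
    // The style annotated throughout this diff: plain subtraction,
    // justified by a SAFETY comment proving `used <= total`.
    total - used
}
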
18 changes: 16 additions & 2 deletions ssz-rs-derive/src/lib.rs
@@ -193,14 +193,25 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
Some(field_name) => quote_spanned! { f.span() =>
let bytes_read = if <#field_type>::is_variable_size() {
let end = start + #BYTES_PER_LENGTH_OFFSET;

let target = encoding.get(start..end).ok_or_else(||
ssz_rs::DeserializeError::ExpectedFurtherInput {
provided: encoding.len() - start,
expected: #BYTES_PER_LENGTH_OFFSET,
}
)?;
let next_offset = u32::deserialize(target)?;
offsets.push((#i, next_offset as usize));
let next_offset = u32::deserialize(target)? as usize;

if let Some((_, previous_offset)) = offsets.last() {
if next_offset < *previous_offset {
return Err(DeserializeError::OffsetNotIncreasing {
start: *previous_offset,
end: next_offset,
})
}
}

offsets.push((#i, next_offset));

#BYTES_PER_LENGTH_OFFSET
} else {
@@ -248,6 +259,9 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
}
)?;
container.__ssz_rs_set_by_index(index, target)?;

// SAFETY: checked subtraction is unnecessary,
// as offsets are increasing; qed
total_bytes_read += end - start;
}

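The first hunk above makes the derived `Deserialize` reject variable-size field offsets that decrease; equal adjacent offsets stay legal because the comparison is strict. A standalone sketch of the same validation, with a plain string standing in for `DeserializeError::OffsetNotIncreasing` (illustrative, not the generated code):

fn validate_offsets(offsets: &[usize]) -> Result<(), String> {
    for pair in offsets.windows(2) {
        let (previous, next) = (pair[0], pair[1]);
        if next < previous {
            // Mirrors OffsetNotIncreasing { start: previous, end: next }.
            return Err(format!("offsets not increasing: {previous} then {next}"));
        }
    }
    Ok(())
}

fn main() {
    assert!(validate_offsets(&[4, 10, 10, 16]).is_ok()); // equal offsets allowed
    assert!(validate_offsets(&[4, 10, 8]).is_err());
}
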
14 changes: 14 additions & 0 deletions ssz-rs/src/bitlist.rs
@@ -51,6 +51,7 @@ impl<const N: usize> fmt::Debug for Bitlist<N> {
let value = i32::from(*bit);
write!(f, "{value}")?;
bits_written += 1;
// SAFETY: checked subtraction is unnecessary, as len >= 1 when this for loop runs; qed
if bits_written % 4 == 0 && index != len - 1 {
write!(f, "_")?;
}
@@ -110,6 +111,7 @@ impl<const N: usize> Bitlist<N> {
*last |= 1u8 << marker_index;
}
}
// SAFETY: checked subtraction is unnecessary, as buffer.len() > start_len; qed
Ok(buffer.len() - start_len)
}
}
@@ -161,10 +163,17 @@ impl<const N: usize> Deserialize for Bitlist<N> {
}

let (last_byte, prefix) = encoding.split_last().unwrap();
if *last_byte == 0u8 {
return Err(DeserializeError::InvalidByte(*last_byte))
}

let mut result = BitlistInner::from_slice(prefix);
let last = BitlistInner::from_element(*last_byte);

// validate bit length satisfies bound `N`
// SAFETY: checked subtraction is unnecessary,
// as last_byte != 0, so last.trailing_zeros <= 7; qed
// therefore: bit_length >= 1
let bit_length = 8 - last.trailing_zeros();
let additional_members = bit_length - 1; // skip marker bit
let total_members = result.len() + additional_members;
@@ -282,6 +291,11 @@ mod tests {
)
.unwrap();
assert_eq!(result, expected);

let bytes = vec![24u8, 0u8];
let result = Bitlist::<COUNT>::deserialize(&bytes).expect_err("test data is incorrect");
let expected = DeserializeError::InvalidByte(0u8);
assert_eq!(result.to_string(), expected.to_string());
}

#[test]
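
Background for the new `InvalidByte` check and the `bit_length` arithmetic above: an SSZ bitlist marks its own length with a single marker bit, the highest set bit of the final byte, so a final byte of zero carries no marker and is invalid. A sketch of the same length recovery on a plain `u8` (hypothetical helper, not part of the crate):

fn data_bits_in_last_byte(last_byte: u8) -> Option<u32> {
    if last_byte == 0 {
        // No marker bit present: mirrors DeserializeError::InvalidByte.
        return None;
    }
    // The marker sits at bit index 7 - leading_zeros; the bits below it are data.
    Some(7 - last_byte.leading_zeros())
}

fn main() {
    assert_eq!(data_bits_in_last_byte(0b0001_1000), Some(4));
    assert_eq!(data_bits_in_last_byte(0b0000_0001), Some(0)); // marker only
    assert_eq!(data_bits_in_last_byte(0), None);
}

This agrees with the diff's `bit_length = 8 - last.trailing_zeros()` followed by `bit_length - 1`: assuming `BitlistInner` uses `bitvec`'s `Lsb0` ordering, trailing zeros of the bit sequence correspond to leading zeros of the raw byte.
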
1 change: 1 addition & 0 deletions ssz-rs/src/bitvector.rs
@@ -60,6 +60,7 @@ impl<const N: usize> fmt::Debug for Bitvector<N> {
let value = i32::from(*bit);
write!(f, "{value}")?;
bits_written += 1;
// SAFETY: checked subtraction is unnecessary, as len >= 1 for bitvectors; qed
if bits_written % 4 == 0 && index != len - 1 {
write!(f, "_")?;
}
1 change: 1 addition & 0 deletions ssz-rs/src/de.rs
@@ -80,6 +80,7 @@
if remainder != 0 {
return Err(DeserializeError::AdditionalInput {
provided: encoding.len(),
// SAFETY: checked subtraction is unnecessary, as encoding.len() > remainder; qed
expected: encoding.len() - remainder,
})
}
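
A worked instance of the hunk above: deserializing a homogeneous collection of fixed-size elements leaves a nonzero remainder exactly when the input is not a whole number of elements, and since `a % b <= a` always holds, computing the `expected` length cannot underflow. A sketch assuming hypothetical 4-byte elements:

fn main() {
    let encoding_len = 11usize; // hypothetical input length
    let element_size = 4usize; // hypothetical fixed element size
    let remainder = encoding_len % element_size; // 3
    // a % b <= a, so this subtraction cannot underflow.
    let expected = encoding_len - remainder;
    assert_eq!(expected, 8); // the largest whole number of elements
}
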
107 changes: 65 additions & 42 deletions ssz-rs/src/merkleization/mod.rs
@@ -43,12 +43,15 @@ impl Display for MerkleizationError {
#[cfg(feature = "std")]
impl std::error::Error for MerkleizationError {}

// Ensures `buffer` can be broken up exactly into chunks of `BYTES_PER_CHUNK` bytes
// by zero-padding any partial chunk at the end of `buffer`
pub fn pack_bytes(buffer: &mut Vec<u8>) {
let data_len = buffer.len();
if data_len % BYTES_PER_CHUNK != 0 {
let bytes_to_pad = BYTES_PER_CHUNK - data_len % BYTES_PER_CHUNK;
let pad = vec![0u8; bytes_to_pad];
buffer.extend_from_slice(&pad);
let incomplete_chunk_len = buffer.len() % BYTES_PER_CHUNK;
if incomplete_chunk_len != 0 {
// SAFETY: checked subtraction is unnecessary,
// as BYTES_PER_CHUNK > incomplete_chunk_len; qed
let bytes_to_pad = BYTES_PER_CHUNK - incomplete_chunk_len;
buffer.resize(buffer.len() + bytes_to_pad, 0);
}
}
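
A usage sketch for the rewritten `pack_bytes`, assuming `BYTES_PER_CHUNK == 32` as elsewhere in the crate; note the rewrite also drops the intermediate `pad` allocation in favor of `Vec::resize`:

#[test]
fn pack_bytes_pads_to_chunk_boundary() {
    let mut unaligned = vec![0xffu8; 33];
    pack_bytes(&mut unaligned);
    assert_eq!(unaligned.len(), 64); // padded up to the next whole chunk
    assert!(unaligned[33..].iter().all(|&b| b == 0)); // with zero bytes

    let mut aligned = vec![0xffu8; 64];
    pack_bytes(&mut aligned);
    assert_eq!(aligned.len(), 64); // already chunk-aligned: unchanged
}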

@@ -99,6 +102,10 @@ include!(concat!(env!("OUT_DIR"), "/context.rs"));
/// of two and this can be quite large for some types. "Zero" subtrees are virtualized to avoid the
/// memory and computation cost of large trees with partially empty leaves.
///
/// The implementation approach treats `chunks` as the bottom layer of a perfect binary tree
/// and for each height performs the hashing required to compute the parent layer in place.
/// This process is repeated until the root is computed.
///
/// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0`
/// Invariant: `leaf_count.next_power_of_two() == leaf_count`
/// Invariant: `leaf_count != 0`
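
The "zero" subtree virtualization mentioned above works because the root of an all-zero subtree depends only on its depth, so one precomputed constant per depth suffices; `CONTEXT` (generated into `OUT_DIR` above) is such a table. A sketch of how a table like it can be built with the `sha2` crate (illustrative, not the build script itself):

use sha2::{Digest, Sha256};

fn zero_subtree_roots(max_depth: usize) -> Vec<[u8; 32]> {
    // table[0] is a zero chunk; table[d] is the root of a zero subtree of depth d.
    let mut table = vec![[0u8; 32]];
    for depth in 1..max_depth {
        let child = table[depth - 1];
        let mut hasher = Sha256::new();
        hasher.update(child); // left child
        hasher.update(child); // right child is identical at equal depth
        table.push(hasher.finalize().into());
    }
    table
}

The `Ordering::Equal` branch below reads its virtual right sibling from exactly such a table (`&CONTEXT[depth as usize]`).
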
@@ -107,80 +114,91 @@ fn merkleize_chunks_with_virtual_padding(
chunks: &[u8],
leaf_count: usize,
) -> Result<Node, MerkleizationError> {
let chunk_count = chunks.len() / BYTES_PER_CHUNK;

let mut hasher = Sha256::new();
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
// NOTE: This also asserts that leaf_count != 0
debug_assert!(leaf_count.next_power_of_two() == leaf_count);
// SAFETY: this holds as long as leaf_count != 0 and usize is no wider than u64
debug_assert!((leaf_count.trailing_zeros() as usize) < MAX_MERKLE_TREE_DEPTH);

let chunk_count = chunks.len() / BYTES_PER_CHUNK;
let height = leaf_count.trailing_zeros() + 1;

if chunk_count == 0 {
// SAFETY: checked subtraction is unnecessary, as height >= 1; qed
let depth = height - 1;
// SAFETY: index is safe while depth == leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH;
// qed
return Ok(CONTEXT[depth as usize].try_into().expect("can produce a single root chunk"))
}

let mut layer = chunks.to_vec();
// SAFETY: checked subtraction is unnecessary, as we return early when chunk_count == 0; qed
let mut last_index = chunk_count - 1;
let mut hasher = Sha256::new();
// for each layer of the tree, starting from the bottom and walking up to the root:
for k in (1..height).rev() {
// for each pair of nodes in this layer:
for i in (0..2usize.pow(k)).step_by(2) {
let parent_index = i / 2;
match i.cmp(&last_index) {
let (parent, left, right) = match i.cmp(&last_index) {
Ordering::Less => {
// SAFETY: index is safe because (i+1)*BYTES_PER_CHUNK < layer.len():
// i < last_index == chunk_count - 1 == (layer.len() / BYTES_PER_CHUNK) - 1
// so i+1 < layer.len() / BYTES_PER_CHUNK
// so (i+1)*BYTES_PER_CHUNK < layer.len(); qed
let focus =
&mut layer[parent_index * BYTES_PER_CHUNK..(i + 2) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 2 - parent_index) * BYTES_PER_CHUNK
// and
// i >= parent_index
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
let (parent, children) = focus.split_at_mut(children_index);
let (left, right) = children.split_at_mut(BYTES_PER_CHUNK);
if parent.is_empty() {
// NOTE: have to specially handle the situation where the children nodes and
// parent node share memory
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}

// NOTE: we do not need mutability on `right` here so drop that capability
(parent, left, &*right)
}
Ordering::Equal => {
// SAFETY: index is safe because i*BYTES_PER_CHUNK < layer.len():
// i*BYTES_PER_CHUNK < (i+1)*BYTES_PER_CHUNK < layer.len()
// (see previous case); qed
let focus =
&mut layer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// and
// i >= parent_index
// so focus.len() >= BYTES_PER_CHUNK; qed
let children_index = focus.len() - BYTES_PER_CHUNK;
let (parent, children) = focus.split_at_mut(children_index);
let (left, _) = children.split_at_mut(BYTES_PER_CHUNK);
// NOTE: left.len() == BYTES_PER_CHUNK
let (parent, left) = focus.split_at_mut(children_index);
// SAFETY: checked subtraction is unnecessary:
// k <= height - 1
// so depth >= height - (height - 1) - 1
// = 0; qed
let depth = height - k - 1;
// SAFETY: index is safe because depth < CONTEXT.len():
// depth <= height - 1 == leaf_count.trailing_zeros()
// leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH == CONTEXT.len(); qed
let right = &CONTEXT[depth as usize];
if parent.is_empty() {
// NOTE: have to specially handle the situation where the children nodes and
// parent node share memory
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
(parent, left, right)
}
_ => break,
};
if i == 0 {
// NOTE: nodes share memory here and so we can't use the `hash_nodes` utility
// as their disjointness is reflected in that function's type signature,
// so instead we replicate its body here.
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
}
last_index /= 2;
}
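
On the `i == 0` special case above: the parent of the first pair occupies the same chunk as its left child, so the two cannot be borrowed as disjoint mutable slices for `hash_nodes`, and the digest is written back over the left child instead. A minimal sketch of that in-place step, assuming the module's `sha2` hasher:

use sha2::{Digest, Sha256};

// Illustrative: overwrite `node` with sha256(node || sibling), the pattern
// used when the parent slot aliases its left child.
fn hash_in_place(hasher: &mut Sha256, node: &mut [u8; 32], sibling: &[u8; 32]) {
    hasher.update(&node[..]);
    hasher.update(sibling);
    node.copy_from_slice(&hasher.finalize_reset());
}
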
@@ -288,30 +306,35 @@ mod tests {
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
debug_assert!(leaf_count.next_power_of_two() == leaf_count);

// SAFETY: checked subtraction is unnecessary,
// as leaf_count != 0 (0.next_power_of_two() == 1); qed
let node_count = 2 * leaf_count - 1;
// SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed
let interior_count = node_count - leaf_count;
let leaf_start = interior_count * BYTES_PER_CHUNK;

let mut hasher = Sha256::new();
let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK];
buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks);
let zero_chunk = [0u8; 32];
for i in (chunks.len() / BYTES_PER_CHUNK)..leaf_count {
let start = leaf_start + (i * BYTES_PER_CHUNK);
let end = leaf_start + (i + 1) * BYTES_PER_CHUNK;
buffer[start..end].copy_from_slice(&zero_chunk);
}

for i in (1..node_count).rev().step_by(2) {
// SAFETY: checked subtraction is unnecessary, as i >= 1; qed
let parent_index = (i - 1) / 2;
let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK
// = ((i + 3) / 2) * BYTES_PER_CHUNK
// and
// i >= 1
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
// NOTE: children.len() == 2 * BYTES_PER_CHUNK
let (parent, children) = focus.split_at_mut(children_index);
let left = &children[0..BYTES_PER_CHUNK];
let right = &children[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK];
let (left, right) = children.split_at(BYTES_PER_CHUNK);
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
Ok(buffer[0..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk"))
Ok(buffer[..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk"))
}
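
A property-style check one might add against this reference implementation; the helper's actual name is not visible in this hunk, so `naive_merkleize` below is purely an illustrative stand-in:

#[test]
fn reference_agrees_with_virtualized_merkleization() {
    // `naive_merkleize` stands in for the reference helper defined above.
    let chunks = vec![1u8; 3 * BYTES_PER_CHUNK];
    let leaf_count = 4; // smallest power of two covering three chunks
    let expected = naive_merkleize(&chunks, leaf_count).unwrap();
    let root = merkleize_chunks_with_virtual_padding(&chunks, leaf_count).unwrap();
    assert_eq!(root, expected);
}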

#[test]
