diff --git a/ssz-rs-derive/src/lib.rs b/ssz-rs-derive/src/lib.rs index 3306384b..f413897e 100644 --- a/ssz-rs-derive/src/lib.rs +++ b/ssz-rs-derive/src/lib.rs @@ -193,14 +193,25 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream { Some(field_name) => quote_spanned! { f.span() => let bytes_read = if <#field_type>::is_variable_size() { let end = start + #BYTES_PER_LENGTH_OFFSET; + let target = encoding.get(start..end).ok_or_else(|| ssz_rs::DeserializeError::ExpectedFurtherInput { provided: encoding.len() - start, expected: #BYTES_PER_LENGTH_OFFSET, } )?; - let next_offset = u32::deserialize(target)?; - offsets.push((#i, next_offset as usize)); + let next_offset = u32::deserialize(target)? as usize; + + if let Some((_, previous_offset)) = offsets.last() { + if next_offset < *previous_offset { + return Err(DeserializeError::OffsetNotIncreasing { + start: *previous_offset, + end: next_offset, + }) + } + } + + offsets.push((#i, next_offset)); #BYTES_PER_LENGTH_OFFSET } else { @@ -248,6 +259,9 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream { } )?; container.__ssz_rs_set_by_index(index, target)?; + + // SAFETY: checked subtraction is unnecessary, + // as offsets are increasing; qed total_bytes_read += end - start; } diff --git a/ssz-rs/src/bitlist.rs b/ssz-rs/src/bitlist.rs index 7a73decd..48dfc3d9 100644 --- a/ssz-rs/src/bitlist.rs +++ b/ssz-rs/src/bitlist.rs @@ -51,6 +51,7 @@ impl fmt::Debug for Bitlist { let value = i32::from(*bit); write!(f, "{value}")?; bits_written += 1; + // SAFETY: checked subtraction is unnecessary, as len >= 1 when this for loop runs; qed if bits_written % 4 == 0 && index != len - 1 { write!(f, "_")?; } @@ -110,6 +111,7 @@ impl Bitlist { *last |= 1u8 << marker_index; } } + // SAFETY: checked subtraction is unnecessary, as buffer.len() > start_len; qed Ok(buffer.len() - start_len) } } @@ -161,10 +163,17 @@ impl Deserialize for Bitlist { } let (last_byte, prefix) = encoding.split_last().unwrap(); + if *last_byte == 0u8 { + return 
Err(DeserializeError::InvalidByte(*last_byte)) + } + let mut result = BitlistInner::from_slice(prefix); let last = BitlistInner::from_element(*last_byte); // validate bit length satisfies bound `N` + // SAFETY: checked subtraction is unnecessary, + // as last_byte != 0, so last.trailing_zeros <= 7; qed + // therefore: bit_length >= 1 let bit_length = 8 - last.trailing_zeros(); let additional_members = bit_length - 1; // skip marker bit let total_members = result.len() + additional_members; @@ -282,6 +291,11 @@ mod tests { ) .unwrap(); assert_eq!(result, expected); + + let bytes = vec![24u8, 0u8]; + let result = Bitlist::::deserialize(&bytes).expect_err("test data is incorrect"); + let expected = DeserializeError::InvalidByte(0u8); + assert_eq!(result.to_string(), expected.to_string()); } #[test] diff --git a/ssz-rs/src/bitvector.rs b/ssz-rs/src/bitvector.rs index 37728099..91d56635 100644 --- a/ssz-rs/src/bitvector.rs +++ b/ssz-rs/src/bitvector.rs @@ -60,6 +60,7 @@ impl fmt::Debug for Bitvector { let value = i32::from(*bit); write!(f, "{value}")?; bits_written += 1; + // SAFETY: checked subtraction is unnecessary, as len >= 1 for bitvectors; qed if bits_written % 4 == 0 && index != len - 1 { write!(f, "_")?; } diff --git a/ssz-rs/src/de.rs b/ssz-rs/src/de.rs index f3f10c52..e4ea52db 100644 --- a/ssz-rs/src/de.rs +++ b/ssz-rs/src/de.rs @@ -80,6 +80,7 @@ where if remainder != 0 { return Err(DeserializeError::AdditionalInput { provided: encoding.len(), + // SAFETY: checked subtraction is unnecessary, as encoding.len() > remainder; qed expected: encoding.len() - remainder, }) } diff --git a/ssz-rs/src/merkleization/mod.rs b/ssz-rs/src/merkleization/mod.rs index 2ae30ab9..d3b4df96 100644 --- a/ssz-rs/src/merkleization/mod.rs +++ b/ssz-rs/src/merkleization/mod.rs @@ -43,12 +43,15 @@ impl Display for MerkleizationError { #[cfg(feature = "std")] impl std::error::Error for MerkleizationError {} +// Ensures `buffer` can be exactly broken up into `BYTES_PER_CHUNK` chunks of 
bytes +// via padding any partial chunks at the end of `buffer` pub fn pack_bytes(buffer: &mut Vec) { - let data_len = buffer.len(); - if data_len % BYTES_PER_CHUNK != 0 { - let bytes_to_pad = BYTES_PER_CHUNK - data_len % BYTES_PER_CHUNK; - let pad = vec![0u8; bytes_to_pad]; - buffer.extend_from_slice(&pad); + let incomplete_chunk_len = buffer.len() % BYTES_PER_CHUNK; + if incomplete_chunk_len != 0 { + // SAFETY: checked subtraction is unnecessary, + // as BYTES_PER_CHUNK > incomplete_chunk_len; qed + let bytes_to_pad = BYTES_PER_CHUNK - incomplete_chunk_len; + buffer.resize(buffer.len() + bytes_to_pad, 0); } } @@ -99,6 +102,10 @@ include!(concat!(env!("OUT_DIR"), "/context.rs")); /// of two and this can be quite large for some types. "Zero" subtrees are virtualized to avoid the /// memory and computation cost of large trees with partially empty leaves. /// +/// The implementation approach treats `chunks` as the bottom layer of a perfect binary tree +/// and for each height performs the hashing required to compute the parent layer in place. +/// This process is repeated until the root is computed. 
+/// /// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0` /// Invariant: `leaf_count.next_power_of_two() == leaf_count` /// Invariant: `leaf_count != 0` @@ -107,18 +114,17 @@ fn merkleize_chunks_with_virtual_padding( chunks: &[u8], leaf_count: usize, ) -> Result { - let chunk_count = chunks.len() / BYTES_PER_CHUNK; - - let mut hasher = Sha256::new(); debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0); // NOTE: This also asserts that leaf_count != 0 debug_assert!(leaf_count.next_power_of_two() == leaf_count); // SAFETY: this holds as long as leaf_count != 0 and usize is no longer than u64 debug_assert!((leaf_count.trailing_zeros() as usize) < MAX_MERKLE_TREE_DEPTH); + let chunk_count = chunks.len() / BYTES_PER_CHUNK; let height = leaf_count.trailing_zeros() + 1; if chunk_count == 0 { + // SAFETY: checked subtraction is unnecessary, as height >= 1; qed let depth = height - 1; // SAFETY: index is safe while depth == leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH; // qed @@ -126,11 +132,15 @@ fn merkleize_chunks_with_virtual_padding( } let mut layer = chunks.to_vec(); + // SAFETY: checked subtraction is unnecessary, as we return early when chunk_count == 0; qed let mut last_index = chunk_count - 1; + let mut hasher = Sha256::new(); + // for each layer of the tree, starting from the bottom and walking up to the root: for k in (1..height).rev() { + // for each pair of nodes in this layer: for i in (0..2usize.pow(k)).step_by(2) { let parent_index = i / 2; - match i.cmp(&last_index) { + let (parent, left, right) = match i.cmp(&last_index) { Ordering::Less => { // SAFETY: index is safe because (i+1)*BYTES_PER_CHUNK < layer.len(): // i < last_index == chunk_count - 1 == (layer.len() / BYTES_PER_CHUNK) - 1 @@ -138,20 +148,17 @@ fn merkleize_chunks_with_virtual_padding( // so (i+1)*BYTES_PER_CHUNK < layer.len(); qed let focus = &mut layer[parent_index * BYTES_PER_CHUNK..(i + 2) * BYTES_PER_CHUNK]; + // SAFETY: checked subtraction is unnecessary: + // focus.len() = (i + 2 
- parent_index) * BYTES_PER_CHUNK + // and + // i >= parent_index + // so focus.len() >= 2 * BYTES_PER_CHUNK; qed let children_index = focus.len() - 2 * BYTES_PER_CHUNK; let (parent, children) = focus.split_at_mut(children_index); let (left, right) = children.split_at_mut(BYTES_PER_CHUNK); - if parent.is_empty() { - // NOTE: have to specially handle the situation where the children nodes and - // parent node share memory - hasher.update(&left); - hasher.update(right); - left.copy_from_slice(&hasher.finalize_reset()); - } else { - // SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and - // parent isn't empty; qed - hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); - } + + // NOTE: we do not need mutability on `right` here so drop that capability + (parent, left, &*right) } Ordering::Equal => { // SAFETY: index is safe because i*BYTES_PER_CHUNK < layer.len(): @@ -159,28 +166,39 @@ fn merkleize_chunks_with_virtual_padding( // (see previous case); qed let focus = &mut layer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK]; + // SAFETY: checked subtraction is unnecessary: + // focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK + // and + // i >= parent_index + // so focus.len() >= BYTES_PER_CHUNK; qed let children_index = focus.len() - BYTES_PER_CHUNK; - let (parent, children) = focus.split_at_mut(children_index); - let (left, _) = children.split_at_mut(BYTES_PER_CHUNK); + // NOTE: left.len() == BYTES_PER_CHUNK + let (parent, left) = focus.split_at_mut(children_index); + // SAFETY: checked subtraction is unnecessary: + // k <= height - 1 + // so depth >= height - (height - 1) - 1 + // = 0; qed let depth = height - k - 1; // SAFETY: index is safe because depth < CONTEXT.len(): // depth <= height - 1 == leaf_count.trailing_zeros() // leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH == CONTEXT.len(); qed let right = &CONTEXT[depth as usize]; - if parent.is_empty() { - // NOTE: have to specially handle the situation 
where the children nodes and - parent node share memory - hasher.update(&left); - hasher.update(right); - left.copy_from_slice(&hasher.finalize_reset()); - } else { - // SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and - // parent isn't empty; qed - hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); - } + (parent, left, right) } _ => break, }; + if i == 0 { + // NOTE: nodes share memory here and so we can't use the `hash_nodes` utility + // as the disjoint nature is reflected in that function's type signature + // so instead we will just replicate it here. + hasher.update(&left); + hasher.update(right); + left.copy_from_slice(&hasher.finalize_reset()); + } else { + // SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and + // parent isn't empty; qed + hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); + } } last_index /= 2; } @@ -288,30 +306,35 @@ mod tests { debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0); debug_assert!(leaf_count.next_power_of_two() == leaf_count); + // SAFETY: checked subtraction is unnecessary, + // as leaf_count != 0 (0.next_power_of_two() == 1); qed let node_count = 2 * leaf_count - 1; + // SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed let interior_count = node_count - leaf_count; let leaf_start = interior_count * BYTES_PER_CHUNK; let mut hasher = Sha256::new(); let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK]; buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks); - let zero_chunk = [0u8; 32]; - for i in chunks.len()..leaf_count { - let start = leaf_start + (i * BYTES_PER_CHUNK); - let end = leaf_start + (i + 1) * BYTES_PER_CHUNK; - buffer[start..end].copy_from_slice(&zero_chunk); - } for i in (1..node_count).rev().step_by(2) { + // SAFETY: checked subtraction is unnecessary, as i >= 1; qed let parent_index = (i - 1) / 2; let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK]; + // 
SAFETY: checked subtraction is unnecessary: + // focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK + // = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK + // = ((i + 3) / 2) * BYTES_PER_CHUNK + // and + // i >= 1 + // so focus.len() >= 2 * BYTES_PER_CHUNK; qed let children_index = focus.len() - 2 * BYTES_PER_CHUNK; + // NOTE: children.len() == 2 * BYTES_PER_CHUNK let (parent, children) = focus.split_at_mut(children_index); - let left = &children[0..BYTES_PER_CHUNK]; - let right = &children[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK]; + let (left, right) = children.split_at(BYTES_PER_CHUNK); hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]); } - Ok(buffer[0..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk")) + Ok(buffer[..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk")) } #[test]