Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Note possible checked_sub cases #57

Merged
merged 25 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fb0616b
Add TODOs for checked_sub locations
doubledup May 10, 2023
404f368
Remove extra split_at_mut
doubledup May 17, 2023
4c0b8f0
Replace slicing with split_at
doubledup May 17, 2023
c7c467e
Replace magic number with BYTES_PER_CHUNK
doubledup May 17, 2023
479b76b
Measure the length of an incomplete last chunk
doubledup May 17, 2023
6e90e37
Handle checked_sub cases
doubledup May 15, 2023
79a7c71
Reformat checked subtraction comments
doubledup May 29, 2023
39d20fa
Clarify comment
doubledup Jun 6, 2023
4d1999a
Add test case for empty last byte in Bitlist
doubledup Jun 7, 2023
536f46a
Add SAFETY prefix to comments
doubledup Jun 7, 2023
392cd26
Replace assert! with OffsetNotIncreasing error
doubledup Jun 7, 2023
b231461
Add notes about slice lengths after splitting
doubledup Jun 7, 2023
8b9f952
Merge branch 'main' into checked-sub
doubledup Jun 7, 2023
c48e8ff
Merge branch 'main' into checked-sub
doubledup Jun 8, 2023
9bf8905
Swap InvalidByte for InstanceError::Bounded
doubledup Jun 8, 2023
2ad9a4b
rustfmt
doubledup Jun 8, 2023
f541e76
Only push increasing offsets
doubledup Jun 9, 2023
8e8add5
Merge branch 'main' into checked-sub
doubledup Jun 9, 2023
e303981
Update ssz-rs-derive/src/lib.rs
ralexstokes Jun 9, 2023
3efa277
Update ssz-rs/src/merkleization/mod.rs
ralexstokes Jun 9, 2023
220b6c4
Update ssz-rs/src/merkleization/mod.rs
ralexstokes Jun 9, 2023
946b442
simplify naive merkle hashing code
ralexstokes Jun 9, 2023
6accced
formatting: code org
ralexstokes Jun 10, 2023
9400cc4
add some clarifying docs
ralexstokes Jun 10, 2023
8bde6b5
factor out hashing from node juggling
ralexstokes Jun 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions ssz-rs-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,25 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
Some(field_name) => quote_spanned! { f.span() =>
let bytes_read = if <#field_type>::is_variable_size() {
let end = start + #BYTES_PER_LENGTH_OFFSET;

let target = encoding.get(start..end).ok_or_else(||
ssz_rs::DeserializeError::ExpectedFurtherInput {
provided: encoding.len() - start,
expected: #BYTES_PER_LENGTH_OFFSET,
}
)?;
let next_offset = u32::deserialize(target)?;
offsets.push((#i, next_offset as usize));
let next_offset = u32::deserialize(target)? as usize;

if let Some((_, previous_offset)) = offsets.last() {
if next_offset < *previous_offset {
return Err(DeserializeError::OffsetNotIncreasing {
start: *previous_offset,
end: next_offset,
})
}
}

offsets.push((#i, next_offset));

#BYTES_PER_LENGTH_OFFSET
} else {
Expand Down Expand Up @@ -248,6 +259,9 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
}
)?;
container.__ssz_rs_set_by_index(index, target)?;

// SAFETY: checked subtraction is unnecessary,
// as offsets are increasing; qed
total_bytes_read += end - start;
}

Expand Down
14 changes: 14 additions & 0 deletions ssz-rs/src/bitlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ impl<const N: usize> fmt::Debug for Bitlist<N> {
let value = i32::from(*bit);
write!(f, "{value}")?;
bits_written += 1;
// SAFETY: checked subtraction is unnecessary, as len >= 1 when this for loop runs; qed
if bits_written % 4 == 0 && index != len - 1 {
write!(f, "_")?;
}
Expand Down Expand Up @@ -110,6 +111,7 @@ impl<const N: usize> Bitlist<N> {
*last |= 1u8 << marker_index;
}
}
// SAFETY: checked subtraction is unnecessary, as buffer.len() > start_len; qed
Ok(buffer.len() - start_len)
}
}
Expand Down Expand Up @@ -161,10 +163,17 @@ impl<const N: usize> Deserialize for Bitlist<N> {
}

let (last_byte, prefix) = encoding.split_last().unwrap();
if *last_byte == 0u8 {
ralexstokes marked this conversation as resolved.
Show resolved Hide resolved
return Err(DeserializeError::InvalidByte(*last_byte))
}

let mut result = BitlistInner::from_slice(prefix);
let last = BitlistInner::from_element(*last_byte);

// validate bit length satisfies bound `N`
// SAFETY: checked subtraction is unnecessary,
// as last_byte != 0, so last.trailing_zeros <= 7; qed
// therefore: bit_length >= 1
let bit_length = 8 - last.trailing_zeros();
let additional_members = bit_length - 1; // skip marker bit
let total_members = result.len() + additional_members;
Expand Down Expand Up @@ -282,6 +291,11 @@ mod tests {
)
.unwrap();
assert_eq!(result, expected);

let bytes = vec![24u8, 0u8];
let result = Bitlist::<COUNT>::deserialize(&bytes).expect_err("test data is incorrect");
let expected = DeserializeError::InvalidByte(0u8);
assert_eq!(result.to_string(), expected.to_string());
ralexstokes marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions ssz-rs/src/bitvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl<const N: usize> fmt::Debug for Bitvector<N> {
let value = i32::from(*bit);
write!(f, "{value}")?;
bits_written += 1;
// SAFETY: checked subtraction is unnecessary, as len >= 1 for bitvectors; qed
if bits_written % 4 == 0 && index != len - 1 {
write!(f, "_")?;
}
Expand Down
1 change: 1 addition & 0 deletions ssz-rs/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ where
if remainder != 0 {
return Err(DeserializeError::AdditionalInput {
provided: encoding.len(),
// SAFETY: checked subtraction is unnecessary, as encoding.len() > remainder; qed
expected: encoding.len() - remainder,
})
}
Expand Down
107 changes: 65 additions & 42 deletions ssz-rs/src/merkleization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ impl Display for MerkleizationError {
#[cfg(feature = "std")]
impl std::error::Error for MerkleizationError {}

// Ensures `buffer` can be exactly broken up into `BYTES_PER_CHUNK` chunks of bytes
// via padding any partial chunks at the end of `buffer`
pub fn pack_bytes(buffer: &mut Vec<u8>) {
ralexstokes marked this conversation as resolved.
Show resolved Hide resolved
let data_len = buffer.len();
if data_len % BYTES_PER_CHUNK != 0 {
let bytes_to_pad = BYTES_PER_CHUNK - data_len % BYTES_PER_CHUNK;
let pad = vec![0u8; bytes_to_pad];
buffer.extend_from_slice(&pad);
let incomplete_chunk_len = buffer.len() % BYTES_PER_CHUNK;
if incomplete_chunk_len != 0 {
// SAFETY: checked subtraction is unnecessary,
// as BYTES_PER_CHUNK > incomplete_chunk_len; qed
let bytes_to_pad = BYTES_PER_CHUNK - incomplete_chunk_len;
buffer.resize(buffer.len() + bytes_to_pad, 0);
}
}

Expand Down Expand Up @@ -99,6 +102,10 @@ include!(concat!(env!("OUT_DIR"), "/context.rs"));
/// of two and this can be quite large for some types. "Zero" subtrees are virtualized to avoid the
/// memory and computation cost of large trees with partially empty leaves.
///
/// The implementation approach treats `chunks` as the bottom layer of a perfect binary tree
/// and for each height performs the hashing required to compute the parent layer in place.
/// This process is repeated until the root is computed.
///
/// Invariant: `chunks.len() % BYTES_PER_CHUNK == 0`
/// Invariant: `leaf_count.next_power_of_two() == leaf_count`
/// Invariant: `leaf_count != 0`
Expand All @@ -107,80 +114,91 @@ fn merkleize_chunks_with_virtual_padding(
chunks: &[u8],
leaf_count: usize,
) -> Result<Node, MerkleizationError> {
let chunk_count = chunks.len() / BYTES_PER_CHUNK;

let mut hasher = Sha256::new();
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
// NOTE: This also asserts that leaf_count != 0
debug_assert!(leaf_count.next_power_of_two() == leaf_count);
// SAFETY: this holds as long as leaf_count != 0 and usize is no longer than u64
debug_assert!((leaf_count.trailing_zeros() as usize) < MAX_MERKLE_TREE_DEPTH);

let chunk_count = chunks.len() / BYTES_PER_CHUNK;
let height = leaf_count.trailing_zeros() + 1;

if chunk_count == 0 {
// SAFETY: checked subtraction is unnecessary, as height >= 1; qed
let depth = height - 1;
// SAFETY: index is safe while depth == leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH;
// qed
return Ok(CONTEXT[depth as usize].try_into().expect("can produce a single root chunk"))
}

let mut layer = chunks.to_vec();
// SAFETY: checked subtraction is unnecessary, as we return early when chunk_count == 0; qed
let mut last_index = chunk_count - 1;
let mut hasher = Sha256::new();
// for each layer of the tree, starting from the bottom and walking up to the root:
for k in (1..height).rev() {
// for each pair of nodes in this layer:
for i in (0..2usize.pow(k)).step_by(2) {
let parent_index = i / 2;
match i.cmp(&last_index) {
let (parent, left, right) = match i.cmp(&last_index) {
Ordering::Less => {
// SAFETY: index is safe because (i+1)*BYTES_PER_CHUNK < layer.len():
// i < last_index == chunk_count - 1 == (layer.len() / BYTES_PER_CHUNK) - 1
// so i+1 < layer.len() / BYTES_PER_CHUNK
// so (i+1)*BYTES_PER_CHUNK < layer.len(); qed
let focus =
&mut layer[parent_index * BYTES_PER_CHUNK..(i + 2) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 2 - parent_index) * BYTES_PER_CHUNK
// and
// i >= parent_index
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
let (parent, children) = focus.split_at_mut(children_index);
let (left, right) = children.split_at_mut(BYTES_PER_CHUNK);
if parent.is_empty() {
// NOTE: have to specially handle the situation where the children nodes and
// parent node share memory
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}

// NOTE: we do not need mutability on `right` here so drop that capability
(parent, left, &*right)
}
Ordering::Equal => {
// SAFETY: index is safe because i*BYTES_PER_CHUNK < layer.len():
// i*BYTES_PER_CHUNK < (i+1)*BYTES_PER_CHUNK < layer.len()
// (see previous case); qed
let focus =
&mut layer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// and
// i >= parent_index
// so focus.len() >= BYTES_PER_CHUNK; qed
let children_index = focus.len() - BYTES_PER_CHUNK;
let (parent, children) = focus.split_at_mut(children_index);
let (left, _) = children.split_at_mut(BYTES_PER_CHUNK);
// NOTE: left.len() == BYTES_PER_CHUNK
let (parent, left) = focus.split_at_mut(children_index);
ralexstokes marked this conversation as resolved.
Show resolved Hide resolved
// SAFETY: checked subtraction is unnecessary:
// k <= height - 1
// so depth >= height - (height - 1) - 1
// = 0; qed
let depth = height - k - 1;
// SAFETY: index is safe because depth < CONTEXT.len():
// depth <= height - 1 == leaf_count.trailing_zeros()
// leaf_count.trailing_zeros() < MAX_MERKLE_TREE_DEPTH == CONTEXT.len(); qed
let right = &CONTEXT[depth as usize];
if parent.is_empty() {
// NOTE: have to specially handle the situation where the children nodes and
// parent node share memory
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
(parent, left, right)
}
_ => break,
};
if i == 0 {
// NOTE: nodes share memory here and so we can't use the `hash_nodes` utility,
// as the disjoint nature is reflected in that function's type signature;
// so instead we just replicate the hashing here.
hasher.update(&left);
hasher.update(right);
left.copy_from_slice(&hasher.finalize_reset());
} else {
// SAFETY: index is safe because parent.len() % BYTES_PER_CHUNK == 0 and
// parent isn't empty; qed
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
}
last_index /= 2;
}
Expand Down Expand Up @@ -288,30 +306,35 @@ mod tests {
debug_assert!(chunks.len() % BYTES_PER_CHUNK == 0);
debug_assert!(leaf_count.next_power_of_two() == leaf_count);

// SAFETY: checked subtraction is unnecessary,
// as leaf_count != 0 (0.next_power_of_two() == 1); qed
let node_count = 2 * leaf_count - 1;
// SAFETY: checked subtraction is unnecessary, as node_count >= leaf_count; qed
let interior_count = node_count - leaf_count;
let leaf_start = interior_count * BYTES_PER_CHUNK;

let mut hasher = Sha256::new();
let mut buffer = vec![0u8; node_count * BYTES_PER_CHUNK];
buffer[leaf_start..leaf_start + chunks.len()].copy_from_slice(chunks);
let zero_chunk = [0u8; 32];
for i in chunks.len()..leaf_count {
let start = leaf_start + (i * BYTES_PER_CHUNK);
let end = leaf_start + (i + 1) * BYTES_PER_CHUNK;
buffer[start..end].copy_from_slice(&zero_chunk);
}

for i in (1..node_count).rev().step_by(2) {
// SAFETY: checked subtraction is unnecessary, as i >= 1; qed
let parent_index = (i - 1) / 2;
let focus = &mut buffer[parent_index * BYTES_PER_CHUNK..(i + 1) * BYTES_PER_CHUNK];
// SAFETY: checked subtraction is unnecessary:
// focus.len() = (i + 1 - parent_index) * BYTES_PER_CHUNK
// = ((2*i + 2 - i + 1) / 2) * BYTES_PER_CHUNK
// = ((i + 3) / 2) * BYTES_PER_CHUNK
// and
// i >= 1
// so focus.len() >= 2 * BYTES_PER_CHUNK; qed
let children_index = focus.len() - 2 * BYTES_PER_CHUNK;
// NOTE: children.len() == 2 * BYTES_PER_CHUNK
let (parent, children) = focus.split_at_mut(children_index);
let left = &children[0..BYTES_PER_CHUNK];
let right = &children[BYTES_PER_CHUNK..2 * BYTES_PER_CHUNK];
let (left, right) = children.split_at(BYTES_PER_CHUNK);
ralexstokes marked this conversation as resolved.
Show resolved Hide resolved
hash_nodes(&mut hasher, left, right, &mut parent[..BYTES_PER_CHUNK]);
}
Ok(buffer[0..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk"))
Ok(buffer[..BYTES_PER_CHUNK].try_into().expect("can produce a single root chunk"))
}

#[test]
Expand Down