Skip to content

Commit

Permalink
safety check on write_fixed_bitset
Browse files Browse the repository at this point in the history
  • Loading branch information
ogabrielides committed Nov 4, 2024
1 parent fb58431 commit 603773d
Showing 1 changed file with 55 additions and 40 deletions.
95 changes: 55 additions & 40 deletions dash/src/consensus/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,12 @@ pub fn read_fixed_bitset<R: Read + ?Sized>(r: &mut R, size: usize) -> std::io::R
}

pub fn write_fixed_bitset<W: Write + ?Sized>(w: &mut W, bits: &[bool], size: usize) -> io::Result<usize> {
if bits.len() < size {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Bits length is less than the specified size",
));
}
// Define a reasonable maximum size to prevent excessive memory allocation
const MAX_BITSET_SIZE: usize = 1_000_000;
if size > MAX_BITSET_SIZE {
Expand Down Expand Up @@ -1491,56 +1497,65 @@ mod tests {
#[test]
fn test_fixed_bitset_round_trip() {
let test_cases = vec![
(vec![], 0),
(vec![true, false, true, false, true, false, true, false], 8),
(vec![true; 10], 10),
(vec![false; 15], 15),
(vec![true, false, true], 16), // size greater than bits.len()
(vec![], 0, true), // (bits, size, expect_success)
(vec![true, false, true, false, true, false, true, false], 8, true),
(vec![true; 10], 10, true),
(vec![false; 15], 15, true),
(vec![true, false, true], 16, false), // size greater than bits.len()
(
vec![
true, false, true, false, true, false, true, false, true, false, true, false,
true, false, true, false, true, false, true, false, true, false, true, false,
],
24,
true,
),
];

for (bits, size) in test_cases {
for (bits, size, expect_success) in test_cases {
let mut buffer = Vec::new();
// Write the bitset to the buffer
let bytes_written = write_fixed_bitset(&mut buffer, &bits, size).expect("Failed to write");
// Calculate expected bytes written
let expected_bytes = (size + 7) / 8;
assert_eq!(
bytes_written, expected_bytes,
"Incorrect number of bytes written for bitset with size {}",
size
);

// Read the bitset back from the buffer
let mut cursor = Cursor::new(&buffer);
let read_bits = read_fixed_bitset(&mut cursor, size).expect("Failed to read");

// Assert that the original bits match the deserialized bits
// For bits beyond bits.len(), they should be false
let expected_bits: Vec<bool> = (0..size)
.map(|i| bits.get(i).copied().unwrap_or(false))
.collect();

assert_eq!(
read_bits, expected_bits,
"Deserialized bits do not match original for size {}",
size
);

// Ensure that we've consumed all bytes (no extra bytes left)
let position = cursor.position();
assert_eq!(
position as usize,
buffer.len(),
"Not all bytes were consumed for size {}",
size
);
// Attempt to write the bitset to the buffer
let result = write_fixed_bitset(&mut buffer, &bits, size);

if expect_success {
// Expect the write to succeed
let bytes_written = result.expect("Failed to write");
// Calculate expected bytes written
let expected_bytes = (size + 7) / 8;
assert_eq!(
bytes_written, expected_bytes,
"Incorrect number of bytes written for bitset with size {}",
size
);

// Read the bitset back from the buffer
let mut cursor = Cursor::new(&buffer);
let read_bits = read_fixed_bitset(&mut cursor, size).expect("Failed to read");

// Assert that the original bits match the deserialized bits
assert_eq!(
read_bits, bits,
"Deserialized bits do not match original for size {}",
size
);

// Ensure that we've consumed all bytes (no extra bytes left)
let position = cursor.position();
assert_eq!(
position as usize,
buffer.len(),
"Not all bytes were consumed for size {}",
size
);
} else {
// Expect the write to fail
assert!(
result.is_err(),
"Expected write to fail for bits.len() < size (size: {}, bits.len(): {})",
size,
bits.len()
);
}
}
}
}

0 comments on commit 603773d

Please sign in to comment.