diff --git a/vortex-buffer/src/bit/buf_mut.rs b/vortex-buffer/src/bit/buf_mut.rs index c0ff973feec..5235e6e92f6 100644 --- a/vortex-buffer/src/bit/buf_mut.rs +++ b/vortex-buffer/src/bit/buf_mut.rs @@ -221,6 +221,16 @@ impl BitBufferMut { unsafe { unset_bit_unchecked(self.buffer.as_mut_ptr(), self.offset + index) } } + /// Foces the length of the `BitBufferMut` to `new_len`. + /// + /// # Safety + /// + /// - `new_len` must be less than or equal to [`capacity()`](Self::capacity) + /// - The elements at `old_len..new_len` must be initialized + pub unsafe fn set_len(&mut self, new_len: usize) { + self.len = new_len; + } + /// Truncate the buffer to the given length. pub fn truncate(&mut self, len: usize) { if len > self.len { @@ -241,18 +251,35 @@ impl BitBufferMut { } } + /// Append a new boolean into the bit buffer without checking for sufficient capacity. + /// + /// # Safety + /// + /// The caller must ensure there is sufficient capacity in the underlying byte buffer to + /// accommodate the new bit. If the bit position requires a new byte to be allocated, the buffer + /// must have capacity for at least one more byte. + pub unsafe fn append_unchecked(&mut self, value: bool) { + if value { + // SAFETY: checked by caller. + unsafe { self.append_true_unchecked() } + } else { + // SAFETY: checked by caller. + unsafe { self.append_false_unchecked() } + } + } + /// Append a new true value to the buffer. pub fn append_true(&mut self) { let bit_pos = self.offset + self.len; let byte_pos = bit_pos / 8; let bit_in_byte = bit_pos % 8; - // Ensure buffer has enough bytes + // Ensure buffer has enough bytes. if byte_pos >= self.buffer.len() { self.buffer.push(0u8); } - // Set the bit + // Set the bit. self.buffer.as_mut_slice()[byte_pos] |= 1 << bit_in_byte; self.len += 1; } @@ -263,12 +290,61 @@ impl BitBufferMut { let byte_pos = bit_pos / 8; let bit_in_byte = bit_pos % 8; - // Ensure buffer has enough bytes + // Ensure buffer has enough bytes. if byte_pos >= self.buffer.len() { self.buffer.push(0u8); } - // Bit is already 0 if we just pushed a new byte, otherwise ensure it's unset + // Bit is already 0 if we just pushed a new byte, otherwise ensure it's unset. + if bit_in_byte != 0 { + self.buffer.as_mut_slice()[byte_pos] &= !(1 << bit_in_byte); + } + + self.len += 1; + } + + /// Append a new true value to the buffer without checking for sufficient capacity. + /// + /// # Safety + /// + /// The caller must ensure there is sufficient capacity in the underlying byte buffer to + /// accommodate the new bit. If the bit position requires a new byte to be allocated, the buffer + /// must have capacity for at least one more byte. + pub unsafe fn append_true_unchecked(&mut self) { + let bit_pos = self.offset + self.len; + let byte_pos = bit_pos / 8; + let bit_in_byte = bit_pos % 8; + + // Ensure buffer has enough bytes. + if byte_pos >= self.buffer.len() { + // SAFETY: caller ensures sufficient capacity. + unsafe { self.buffer.push_unchecked(0u8) }; + } + + // Set the bit. + self.buffer.as_mut_slice()[byte_pos] |= 1 << bit_in_byte; + self.len += 1; + } + + /// Append a new false value to the buffer without checking for sufficient capacity. + /// + /// # Safety + /// + /// The caller must ensure there is sufficient capacity in the underlying byte buffer to + /// accommodate the new bit. If the bit position requires a new byte to be allocated, the buffer + /// must have capacity for at least one more byte. + pub unsafe fn append_false_unchecked(&mut self) { + let bit_pos = self.offset + self.len; + let byte_pos = bit_pos / 8; + let bit_in_byte = bit_pos % 8; + + // Ensure buffer has enough bytes. + if byte_pos >= self.buffer.len() { + // SAFETY: caller ensures sufficient capacity. + unsafe { self.buffer.push_unchecked(0u8) }; + } + + // Bit is already 0 if we just pushed a new byte, otherwise ensure it's unset. if bit_in_byte != 0 { self.buffer.as_mut_slice()[byte_pos] &= !(1 << bit_in_byte); } @@ -483,24 +559,35 @@ impl From> for BitBufferMut { impl FromIterator for BitBufferMut { fn from_iter>(iter: T) -> Self { - let iter = iter.into_iter(); - let (low, high) = iter.size_hint(); - if let Some(len) = high { - let mut buf = BitBufferMut::new_unset(len); - for (i, v) in iter.enumerate() { - if v { - // SAFETY: i is in bounds - unsafe { buf.set_unchecked(i) } - } - } - buf - } else { - let mut buf = BitBufferMut::with_capacity(low); - for v in iter { - buf.append(v); + let mut iter = iter.into_iter(); + + // Note that these hints might be incorrect. + let (lower_bound, upper_bound_opt) = iter.size_hint(); + let capacity = upper_bound_opt.unwrap_or(lower_bound); + + let mut buf = BitBufferMut::new_unset(capacity); + + // Directly write within our known capacity. + for i in 0..capacity { + let Some(v) = iter.next() else { + // SAFETY: We are definitely under the capacity and all values are already + // initialized from `new_unset`. + unsafe { buf.set_len(i) }; + return buf; + }; + + if v { + // SAFETY: We have ensured that we are within the capacity. + unsafe { buf.set_unchecked(i) } } - buf } + + // Append the remaining items (as we do not know how many more there are). + for v in iter { + buf.append(v); + } + + buf } } @@ -918,4 +1005,117 @@ mod tests { assert_eq!(frozen.offset(), 3); assert_eq!(frozen.len(), 6); } + + #[test] + fn test_append_unchecked() { + // Test that append_unchecked works correctly when there's sufficient capacity. + let mut bit_buf = BitBufferMut::with_capacity(100); + + // Reserve enough space for our operations. + bit_buf.reserve(50); + + // Use append_unchecked to add various patterns. + unsafe { + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(false); + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(false); + } + + assert_eq!(bit_buf.len(), 5); + assert!(bit_buf.value(0)); + assert!(!bit_buf.value(1)); + assert!(bit_buf.value(2)); + assert!(bit_buf.value(3)); + assert!(!bit_buf.value(4)); + + // Test appending across byte boundaries. + unsafe { + // Add bits to fill first byte. + bit_buf.append_unchecked(false); + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(false); + // Now at 8 bits (full byte). + + // Add more bits into second byte. + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(true); + bit_buf.append_unchecked(false); + } + + assert_eq!(bit_buf.len(), 11); + assert!(bit_buf.value(8)); // First bit of second byte. + assert!(bit_buf.value(9)); + assert!(!bit_buf.value(10)); + + // Test with offset. + let buf = BufferMut::zeroed(4); + let mut bit_buf_with_offset = BitBufferMut::from_buffer(buf, 3, 0); + bit_buf_with_offset.reserve(20); + + unsafe { + bit_buf_with_offset.append_unchecked(true); + bit_buf_with_offset.append_unchecked(false); + bit_buf_with_offset.append_unchecked(true); + } + + assert_eq!(bit_buf_with_offset.len(), 3); + assert!(bit_buf_with_offset.value(0)); + assert!(!bit_buf_with_offset.value(1)); + assert!(bit_buf_with_offset.value(2)); + } + + #[test] + fn test_from_iterator_with_incorrect_size_hint() { + // This test catches a bug where FromIterator assumed the upper bound + // from size_hint was accurate. The iterator contract allows the actual + // count to exceed the upper bound, which could cause UB if we used + // append_unchecked beyond the allocated capacity. + + // Custom iterator that lies about its size hint. + struct LyingIterator { + values: Vec, + index: usize, + } + + impl Iterator for LyingIterator { + type Item = bool; + + fn next(&mut self) -> Option { + (self.index < self.values.len()).then(|| { + let val = self.values[self.index]; + self.index += 1; + val + }) + } + + fn size_hint(&self) -> (usize, Option) { + // Deliberately return an incorrect upper bound that's smaller + // than the actual number of elements we'll yield. + let remaining = self.values.len() - self.index; + let lower = remaining.min(5); // Correct lower bound (but capped). + let upper = Some(5); // Incorrect upper bound - we actually have more! + (lower, upper) + } + } + + // Create an iterator that claims to have at most 5 elements but actually has 10. + let lying_iter = LyingIterator { + values: vec![ + true, false, true, false, true, false, true, false, true, false, + ], + index: 0, + }; + + // Collect the iterator. This would cause UB in the old implementation + // if it trusted the upper bound and used append_unchecked beyond capacity. + let bit_buf: BitBufferMut = lying_iter.collect(); + + // Verify all 10 elements were collected correctly. + assert_eq!(bit_buf.len(), 10); + for i in 0..10 { + assert_eq!(bit_buf.value(i), i % 2 == 0); + } + } } diff --git a/vortex-buffer/src/buffer_mut.rs b/vortex-buffer/src/buffer_mut.rs index c28f7b7150f..8acb6cda6c1 100644 --- a/vortex-buffer/src/buffer_mut.rs +++ b/vortex-buffer/src/buffer_mut.rs @@ -477,59 +477,105 @@ impl AsMut<[T]> for BufferMut { } impl BufferMut { + /// A helper method for the two [`Extend`] implementations. + /// + /// We use the lower bound hint on the iterator to manually write data, and then we continue to + /// push items normally past the lower bound. fn extend_iter(&mut self, mut iter: impl Iterator) { - // Attempt to reserve enough memory up-front, although this is only a lower bound. - let (lower, _) = iter.size_hint(); - self.reserve(lower); + // Since we do not know the length of the iterator, we can only guess how much memory we + // need to reserve. Note that these hints may be inaccurate. + let (lower_bound, upper_bound_opt) = iter.size_hint(); + self.reserve(upper_bound_opt.unwrap_or(lower_bound)); - let remaining = self.capacity() - self.len(); + let unwritten = self.capacity() - self.len(); + // We store `begin` in the case that the lower bound hint is incorrect. let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast(); let mut dst: *mut T = begin.cast_mut(); - for _ in 0..remaining { - if let Some(item) = iter.next() { - unsafe { - // SAFETY: We know we have enough capacity to write the item. - dst.write(item); - // Note. we used to have dst.add(iteration).write(item), here. - // however this was much slower than just incrementing dst. - dst = dst.add(1); - } - } else { + + // As a first step, we manually iterate the iterator up to the known capacity. + for _ in 0..unwritten { + let Some(item) = iter.next() else { + // The lower bound hint may be incorrect. break; - } + }; + + // SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer + // derived from a valid reference to byte data. + unsafe { dst.write(item) }; + + // Note: We used to have `dst.add(iteration).write(item)`, here. However this was much + // slower than just incrementing `dst`. + // SAFETY: The offsets fits in `isize`, and because we were able to reserve the memory + // we know that `add` will not overflow. + unsafe { dst = dst.add(1) }; } - // TODO(joe): replace with ptr_sub when stable - let length = self.len() + unsafe { dst.byte_offset_from(begin) as usize / size_of::() }; + // SAFETY: `dst` was derived from `begin`, which were both valid references to byte data, + // and since the only operation that `dst` has is `add`, we know that `dst >= begin`. + let items_written = unsafe { dst.offset_from_unsigned(begin) }; + let length = self.len() + items_written; + + // SAFETY: We have written valid items between the old length and the new length. unsafe { self.set_len(length) }; - // Append remaining elements + // Finally, since the iterator will have arbitrarily more items to yield, we push the + // remaining items normally. iter.for_each(|item| self.push(item)); } - /// An unsafe variant of the `Extend` trait and its `extend` method that receives what the - /// caller guarantees to be an iterator with a trusted upper bound. + /// Extends the `BufferMut` with an iterator with `TrustedLen`. + /// + /// The caller guarantees that the iterator will have a trusted upper bound, which allows the + /// implementation to reserve all of the memory needed up front. pub fn extend_trusted>(&mut self, iter: I) { - // Reserve all memory upfront since it's an exact upper bound - let (_, high) = iter.size_hint(); - self.reserve(high.vortex_expect("TrustedLen iterator didn't have valid upper bound")); + // Since we know the exact upper bound (from `TrustedLen`), we can reserve all of the memory + // for this operation up front. + let (_, upper_bound) = iter.size_hint(); + self.reserve( + upper_bound + .vortex_expect("`TrustedLen` iterator somehow didn't have valid upper bound"), + ); + // We store `begin` in the case that the upper bound hint is incorrect. let begin: *const T = self.bytes.spare_capacity_mut().as_mut_ptr().cast(); let mut dst: *mut T = begin.cast_mut(); + iter.for_each(|item| { - unsafe { - // SAFETY: We know we have enough capacity to write the item. - dst.write(item); - // Note. we used to have dst.add(iteration).write(item), here. - // however this was much slower than just incrementing dst. - dst = dst.add(1); - } + // SAFETY: We have reserved enough capacity to hold this item, and `dst` is a pointer + // derived from a valid reference to byte data. + unsafe { dst.write(item) }; + + // Note: We used to have `dst.add(iteration).write(item)`, here. However this was much + // slower than just incrementing `dst`. + // SAFETY: The offsets fits in `isize`, and because we were able to reserve the memory + // we know that `add` will not overflow. + unsafe { dst = dst.add(1) }; }); - // TODO(joe): replace with ptr_sub when stable - let length = self.len() + unsafe { dst.byte_offset_from(begin) as usize / size_of::() }; + + // SAFETY: `dst` was derived from `begin`, which were both valid references to byte data, + // and since the only operation that `dst` has is `add`, we know that `dst >= begin`. + let items_written = unsafe { dst.offset_from_unsigned(begin) }; + let length = self.len() + items_written; + + // SAFETY: We have written valid items between the old length and the new length. unsafe { self.set_len(length) }; } + + /// Creates a `BufferMut` from an iterator with a trusted length. + /// + /// Internally, this calls [`extend_trusted()`](Self::extend_trusted). + pub fn from_trusted_len_iter(iter: I) -> Self + where + I: TrustedLen, + { + // We allow the `extend_trusted` method to correctly allocate the required memory. + let mut buffer = Self::with_capacity(0); + buffer.extend_trusted(iter); + + debug_assert_eq!(buffer.alignment(), Alignment::of::()); + buffer + } } impl Extend for BufferMut {