diff --git a/vortex-buffer/src/bit/buf_mut.rs b/vortex-buffer/src/bit/buf_mut.rs index c0ff973feec..9cfb53bd510 100644 --- a/vortex-buffer/src/bit/buf_mut.rs +++ b/vortex-buffer/src/bit/buf_mut.rs @@ -221,6 +221,20 @@ impl BitBufferMut { unsafe { unset_bit_unchecked(self.buffer.as_mut_ptr(), self.offset + index) } } + /// Foces the length of the `BitBufferMut` to `new_len`. + /// + /// # Safety + /// + /// - `new_len` must be less than or equal to [`capacity()`](Self::capacity) + /// - The elements at `old_len..new_len` must be initialized + pub unsafe fn set_len(&mut self, new_len: usize) { + debug_assert!( + new_len <= self.capacity(), + "`set_len` requires that new_len <= capacity()" + ); + self.len = new_len; + } + /// Truncate the buffer to the given length. pub fn truncate(&mut self, len: usize) { if len > self.len { @@ -483,24 +497,35 @@ impl From> for BitBufferMut { impl FromIterator for BitBufferMut { fn from_iter>(iter: T) -> Self { - let iter = iter.into_iter(); - let (low, high) = iter.size_hint(); - if let Some(len) = high { - let mut buf = BitBufferMut::new_unset(len); - for (i, v) in iter.enumerate() { - if v { - // SAFETY: i is in bounds - unsafe { buf.set_unchecked(i) } - } - } - buf - } else { - let mut buf = BitBufferMut::with_capacity(low); - for v in iter { - buf.append(v); + let mut iter = iter.into_iter(); + + // Note that these hints might be incorrect. + let (lower_bound, upper_bound_opt) = iter.size_hint(); + let capacity = upper_bound_opt.unwrap_or(lower_bound); + + let mut buf = BitBufferMut::new_unset(capacity); + + // Directly write within our known capacity. + for i in 0..capacity { + let Some(v) = iter.next() else { + // SAFETY: We are definitely under the capacity and all values are already + // initialized from `new_unset`. + unsafe { buf.set_len(i) }; + return buf; + }; + + if v { + // SAFETY: We have ensured that we are within the capacity. + unsafe { buf.set_unchecked(i) } } - buf } + + // Append the remaining items (as we do not know how many more there are). + for v in iter { + buf.append(v); + } + + buf } } @@ -918,4 +943,57 @@ mod tests { assert_eq!(frozen.offset(), 3); assert_eq!(frozen.len(), 6); } + + #[test] + fn test_from_iterator_with_incorrect_size_hint() { + // This test catches a bug where FromIterator assumed the upper bound + // from size_hint was accurate. The iterator contract allows the actual + // count to exceed the upper bound, which could cause UB if we used + // append_unchecked beyond the allocated capacity. + + // Custom iterator that lies about its size hint. + struct LyingIterator { + values: Vec, + index: usize, + } + + impl Iterator for LyingIterator { + type Item = bool; + + fn next(&mut self) -> Option { + (self.index < self.values.len()).then(|| { + let val = self.values[self.index]; + self.index += 1; + val + }) + } + + fn size_hint(&self) -> (usize, Option) { + // Deliberately return an incorrect upper bound that's smaller + // than the actual number of elements we'll yield. + let remaining = self.values.len() - self.index; + let lower = remaining.min(5); // Correct lower bound (but capped). + let upper = Some(5); // Incorrect upper bound - we actually have more! + (lower, upper) + } + } + + // Create an iterator that claims to have at most 5 elements but actually has 10. + let lying_iter = LyingIterator { + values: vec![ + true, false, true, false, true, false, true, false, true, false, + ], + index: 0, + }; + + // Collect the iterator. This would cause UB in the old implementation + // if it trusted the upper bound and used append_unchecked beyond capacity. + let bit_buf: BitBufferMut = lying_iter.collect(); + + // Verify all 10 elements were collected correctly. + assert_eq!(bit_buf.len(), 10); + for i in 0..10 { + assert_eq!(bit_buf.value(i), i % 2 == 0); + } + } }