diff --git a/arrow-buffer/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs
index c7f6f94fb869..9cd9fd11d090 100644
--- a/arrow-buffer/src/util/bit_iterator.rs
+++ b/arrow-buffer/src/util/bit_iterator.rs
@@ -23,6 +23,7 @@ use crate::bit_util::{ceil, get_bit_raw};
 /// Iterator over the bits within a packed bitmask
 ///
 /// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`]
+#[derive(Clone)]
 pub struct BitIterator<'a> {
     buffer: &'a [u8],
     current_offset: usize,
@@ -71,6 +72,71 @@ impl Iterator for BitIterator<'_> {
         let remaining_bits = self.end_offset - self.current_offset;
         (remaining_bits, Some(remaining_bits))
     }
+
+    fn count(self) -> usize
+    where
+        Self: Sized,
+    {
+        self.len()
+    }
+
+    fn nth(&mut self, n: usize) -> Option<Self::Item> {
+        // Check whether we can jump straight to the desired bit.
+        // When n is 0 we want the next() value,
+        // and when n is 1 we want the next().next() value,
+        // so we add n (and not n - 1) to the current offset.
+        match self.current_offset.checked_add(n) {
+            // Yes, and still within bounds
+            Some(new_offset) if new_offset < self.end_offset => {
+                self.current_offset = new_offset;
+            }
+
+            // Either overflow or would reach/exceed end_offset: exhaust the iterator
+            _ => {
+                self.current_offset = self.end_offset;
+                return None;
+            }
+        }
+
+        self.next()
+    }
+
+    fn last(mut self) -> Option<Self::Item> {
+        // If already at the end, return None
+        if self.current_offset == self.end_offset {
+            return None;
+        }
+
+        // Jump straight to the last bit
+        self.current_offset = self.end_offset - 1;
+
+        // Return the last bit
+        self.next()
+    }
+
+    fn max(self) -> Option<Self::Item>
+    where
+        Self: Sized,
+        Self::Item: Ord,
+    {
+        if self.current_offset == self.end_offset {
+            return None;
+        }
+
+        // true is greater than false so we only need to check if there's any true bit
+        let mut bit_index_iter = BitIndexIterator::new(
+            self.buffer,
+            self.current_offset,
+            self.end_offset - self.current_offset,
+        );
+
+        if bit_index_iter.next().is_some() {
+            return Some(true);
+        }
+
+        // We know the iterator is not empty and there are no set bits so false is the max
+        Some(false)
+    }
 }
 
 impl ExactSizeIterator for BitIterator<'_> {}
@@ -86,6 +152,27 @@ impl DoubleEndedIterator for BitIterator<'_> {
         let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) };
         Some(v)
     }
+
+    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
+        // Check whether we can pull the end back so it sits just past the desired bit.
+        // When n is 0 we want the next_back() value,
+        // and when n is 1 we want the next_back().next_back() value,
+        // so we subtract n (and not n - 1) from the end offset.
+        match self.end_offset.checked_sub(n) {
+            // Yes, and still within bounds
+            Some(new_offset) if self.current_offset < new_offset => {
+                self.end_offset = new_offset;
+            }
+
+            // Either underflow or would reach/pass current_offset: exhaust the iterator
+            _ => {
+                self.current_offset = self.end_offset;
+                return None;
+            }
+        }
+
+        self.next_back()
+    }
 }
 
 /// Iterator of contiguous ranges of set bits within a provided packed bitmask
@@ -327,6 +414,12 @@ pub fn try_for_each_valid_idx<E, F: FnMut(usize) -> Result<(), E>>(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::BooleanBuffer;
+    use rand::rngs::StdRng;
+    use rand::{Rng, SeedableRng};
+    use std::fmt::Debug;
+    use std::iter::Copied;
+    use std::slice::Iter;
 
     #[test]
     fn test_bit_iterator_size_hint() {
@@ -486,4 +579,426 @@ mod tests {
             .collect();
         assert_eq!(result, expected);
     }
+
+    trait SharedBetweenBitIteratorAndSliceIter:
+        ExactSizeIterator<Item = bool> + DoubleEndedIterator<Item = bool>
+    {
+    }
+    impl<T: ExactSizeIterator<Item = bool> + DoubleEndedIterator<Item = bool>>
+        SharedBetweenBitIteratorAndSliceIter for T
+    {
+    }
+
+    fn mutate_2_iters<T: SharedBetweenBitIteratorAndSliceIter>(
+        mut bit_iterator: BitIterator,
+        mut source: T,
+        mutate_fn: impl Fn(&mut dyn SharedBetweenBitIteratorAndSliceIter),
+    ) -> (BitIterator, T) {
+        mutate_fn(&mut bit_iterator);
+        mutate_fn(&mut source);
+
+        (bit_iterator, source)
+    }
+
+    fn get_bit_iterator_cases() -> impl Iterator<Item = (BooleanBuffer, Vec<bool>)> {
+        let mut rng = StdRng::seed_from_u64(42);
+
+        [0, 1, 6, 8, 100, 164]
+            .map(|len| {
+                let source = (0..len).map(|_| rng.random_bool(0.5)).collect::<Vec<_>>();
+
+                (BooleanBuffer::from(source.as_slice()), source)
+            })
+            .into_iter()
+    }
+
+    fn setup_and_assert(
+        setup_iters: impl Fn(&mut dyn SharedBetweenBitIteratorAndSliceIter),
+        assert_fn: impl Fn(BitIterator, Copied<Iter<bool>>),
+    ) {
+        for (boolean_buffer, source) in get_bit_iterator_cases() {
+            let (actual, expected) = mutate_2_iters(
+                BitIterator::new(boolean_buffer.values(), 0, boolean_buffer.len()),
+                source.iter().copied(),
+                &setup_iters,
+            );
+
+            assert_fn(actual, expected);
+        }
+    }
+
+    trait Op {
+        type Output: PartialEq + Debug;
+        const NAME: &'static str;
+
+        fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(iter: T) -> Self::Output;
+    }
+
+    fn assert_cases<O: Op>() {
+        setup_and_assert(
+            |_iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {},
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                iter.next();
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                iter.next_back();
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                iter.next();
+                iter.next_back();
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                while iter.len() > 1 {
+                    iter.next();
+                }
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                while iter.len() > 1 {
+                    iter.next_back();
+                }
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                while iter.next().is_some() {}
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+
+        setup_and_assert(
+            |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {
+                while iter.next_back().is_some() {}
+            },
+            |actual, expected| {
+                let current_iterator_values: Vec<bool> = expected.clone().collect();
+
+                assert_eq!(
+                    O::get_value(actual),
+                    O::get_value(expected),
+                    "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})",
+                    O::NAME
+                );
+            },
+        );
+    }
+
+    #[test]
+    fn assert_bit_iterator_count() {
+        struct CountOp;
+
+        impl Op for CountOp {
+            type Output = usize;
+            const NAME: &'static str = "count";
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(iter: T) -> Self::Output {
+                iter.count()
+            }
+        }
+
+        assert_cases::<CountOp>()
+    }
+
+    #[test]
+    fn assert_bit_iterator_last() {
+        struct LastOp;
+
+        impl Op for LastOp {
+            type Output = Option<bool>;
+            const NAME: &'static str = "last";
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(iter: T) -> Self::Output {
+                iter.last()
+            }
+        }
+
+        assert_cases::<LastOp>()
+    }
+
+    #[test]
+    fn assert_bit_iterator_max() {
+        struct MaxOp;
+
+        impl Op for MaxOp {
+            type Output = Option<bool>;
+            const NAME: &'static str = "max";
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(iter: T) -> Self::Output {
+                iter.max()
+            }
+        }
+
+        assert_cases::<MaxOp>()
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_0() {
+        struct NthOp<const BACK: bool>;
+
+        impl<const BACK: bool> Op for NthOp<BACK> {
+            type Output = Option<bool>;
+            const NAME: &'static str = if BACK { "nth_back(0)" } else { "nth(0)" };
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(mut iter: T) -> Self::Output {
+                if BACK { iter.nth_back(0) } else { iter.nth(0) }
+            }
+        }
+
+        assert_cases::<NthOp<false>>();
+        assert_cases::<NthOp<true>>();
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_1() {
+        struct NthOp<const BACK: bool>;
+
+        impl<const BACK: bool> Op for NthOp<BACK> {
+            type Output = Option<bool>;
+            const NAME: &'static str = if BACK { "nth_back(1)" } else { "nth(1)" };
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(mut iter: T) -> Self::Output {
+                if BACK { iter.nth_back(1) } else { iter.nth(1) }
+            }
+        }
+
+        assert_cases::<NthOp<false>>();
+        assert_cases::<NthOp<true>>();
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_after_end() {
+        struct NthOp<const BACK: bool>;
+
+        impl<const BACK: bool> Op for NthOp<BACK> {
+            type Output = Option<bool>;
+            const NAME: &'static str = if BACK {
+                "nth_back(iter.len() + 1)"
+            } else {
+                "nth(iter.len() + 1)"
+            };
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(mut iter: T) -> Self::Output {
+                if BACK {
+                    iter.nth_back(iter.len() + 1)
+                } else {
+                    iter.nth(iter.len() + 1)
+                }
+            }
+        }
+
+        assert_cases::<NthOp<false>>();
+        assert_cases::<NthOp<true>>();
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_len() {
+        struct NthOp<const BACK: bool>;
+
+        impl<const BACK: bool> Op for NthOp<BACK> {
+            type Output = Option<bool>;
+            const NAME: &'static str = if BACK {
+                "nth_back(iter.len())"
+            } else {
+                "nth(iter.len())"
+            };
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(mut iter: T) -> Self::Output {
+                if BACK {
+                    iter.nth_back(iter.len())
+                } else {
+                    iter.nth(iter.len())
+                }
+            }
+        }
+
+        assert_cases::<NthOp<false>>();
+        assert_cases::<NthOp<true>>();
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_last() {
+        struct NthOp<const BACK: bool>;
+
+        impl<const BACK: bool> Op for NthOp<BACK> {
+            type Output = Option<bool>;
+            const NAME: &'static str = if BACK {
+                "nth_back(iter.len().saturating_sub(1))"
+            } else {
+                "nth(iter.len().saturating_sub(1))"
+            };
+
+            fn get_value<T: SharedBetweenBitIteratorAndSliceIter>(mut iter: T) -> Self::Output {
+                if BACK {
+                    iter.nth_back(iter.len().saturating_sub(1))
+                } else {
+                    iter.nth(iter.len().saturating_sub(1))
+                }
+            }
+        }
+
+        assert_cases::<NthOp<false>>();
+        assert_cases::<NthOp<true>>();
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_and_reuse() {
+        setup_and_assert(
+            |_| {},
+            |actual, expected| {
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        #[allow(clippy::iter_nth_zero)]
+                        let actual_val = actual.nth(0);
+                        #[allow(clippy::iter_nth_zero)]
+                        let expected_val = expected.nth(0);
+                        assert_eq!(actual_val, expected_val, "Failed on nth(0)");
+                    }
+                }
+
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        let actual_val = actual.nth(1);
+                        let expected_val = expected.nth(1);
+                        assert_eq!(actual_val, expected_val, "Failed on nth(1)");
+                    }
+                }
+
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        let actual_val = actual.nth(2);
+                        let expected_val = expected.nth(2);
+                        assert_eq!(actual_val, expected_val, "Failed on nth(2)");
+                    }
+                }
+            },
+        );
+    }
+
+    #[test]
+    fn assert_bit_iterator_nth_back_and_reuse() {
+        setup_and_assert(
+            |_| {},
+            |actual, expected| {
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        #[allow(clippy::iter_nth_zero)]
+                        let actual_val = actual.nth_back(0);
+                        let expected_val = expected.nth_back(0);
+                        assert_eq!(actual_val, expected_val, "Failed on nth_back(0)");
+                    }
+                }
+
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        let actual_val = actual.nth_back(1);
+                        let expected_val = expected.nth_back(1);
+                        assert_eq!(actual_val, expected_val, "Failed on nth_back(1)");
+                    }
+                }
+
+                {
+                    let mut actual = actual.clone();
+                    let mut expected = expected.clone();
+                    for _ in 0..expected.len() {
+                        let actual_val = actual.nth_back(2);
+                        let expected_val = expected.nth_back(2);
+                        assert_eq!(actual_val, expected_val, "Failed on nth_back(2)");
+                    }
+                }
+            },
+        );
+    }
 }