From ffb96c375be9bf1fbce5cf805ffd5f35667a223b Mon Sep 17 00:00:00 2001 From: KAZUYUKI TANIMURA Date: Fri, 2 Aug 2024 14:36:09 -0700 Subject: [PATCH] fix: optimize some bit functions (#718) ## Which issue does this PR close? Part of #679 and #670 ## Rationale for this change The improvement could be negligible in real use cases, but I see some improvements in micro benchmarks ## What changes are included in this PR? Optimizations in some bit functions ## How are these changes tested? Existing tests --- native/core/benches/bit_util.rs | 35 ++++++++- native/core/src/common/bit.rs | 96 ++++++++++++++--------- native/core/src/parquet/mutable_vector.rs | 2 +- native/core/src/parquet/read/levels.rs | 4 +- 4 files changed, 98 insertions(+), 39 deletions(-) diff --git a/native/core/benches/bit_util.rs b/native/core/benches/bit_util.rs index e92dd6375..91471a890 100644 --- a/native/core/benches/bit_util.rs +++ b/native/core/benches/bit_util.rs @@ -21,7 +21,8 @@ use rand::{thread_rng, Rng}; use arrow::buffer::Buffer; use comet::common::bit::{ - log2, read_num_bytes_u32, read_num_bytes_u64, set_bits, BitReader, BitWriter, + log2, read_num_bytes_u32, read_num_bytes_u64, read_u32, read_u64, set_bits, trailing_bits, + BitReader, BitWriter, }; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; @@ -159,6 +160,38 @@ fn criterion_benchmark(c: &mut Criterion) { ); } + // trailing_bits + for length in (0..=64).step_by(32) { + let x = length; + group.bench_with_input( + BenchmarkId::new("trailing_bits", format!("num_bits_{}", x)), + &x, + |b, &x| { + b.iter(|| trailing_bits(black_box(1234567890), black_box(x))); + }, + ); + } + + // read_u64 + group.bench_function("read_u64", |b| { + b.iter(|| read_u64(black_box(&[0u8; 8]))); + }); + + // read_u32 + group.bench_function("read_u32", |b| { + b.iter(|| read_u32(black_box(&[0u8; 4]))); + }); + + // get_u32_value + group.bench_function("get_u32_value", |b| { + b.iter(|| { + let mut reader: BitReader = BitReader::new_all(buffer.slice(0)); + for _ in 0..(buffer.len() * 8 / 31) { + black_box(reader.get_u32_value(black_box(31))); + } + }) + }); + group.finish(); } diff --git a/native/core/src/common/bit.rs b/native/core/src/common/bit.rs index 59b9e4e43..e3937b4a1 100644 --- a/native/core/src/common/bit.rs +++ b/native/core/src/common/bit.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{cmp, cmp::min, mem::size_of}; +use std::{cmp::min, mem::size_of}; use arrow::buffer::Buffer; @@ -131,6 +131,18 @@ pub fn read_num_bytes_u32(size: usize, src: &[u8]) -> u32 { trailing_bits(v as u64, size * 8) as u32 } +#[inline] +pub fn read_u64(src: &[u8]) -> u64 { + let in_ptr = src.as_ptr() as *const u64; + unsafe { in_ptr.read_unaligned() } +} + +#[inline] +pub fn read_u32(src: &[u8]) -> u32 { + let in_ptr = src.as_ptr() as *const u32; + unsafe { in_ptr.read_unaligned() } +} + /// Similar to the `read_num_bytes` but read nums from bytes in big-endian order /// This is used to read bytes from Java's OutputStream which writes bytes in big-endian macro_rules! read_num_be_bytes { @@ -189,8 +201,7 @@ pub fn trailing_bits(v: u64, num_bits: usize) -> u64 { if unlikely(num_bits >= 64) { return v; } - let n = 64 - num_bits; - (v << n) >> n + v & ((1 << num_bits) - 1) } #[inline] @@ -555,8 +566,11 @@ pub struct BitReader { /// either byte aligned or not. impl BitReader { pub fn new(buf: Buffer, len: usize) -> Self { - let num_bytes = cmp::min(8, len); - let buffered_values = read_num_bytes_u64(num_bytes, buf.as_slice()); + let buffered_values = if size_of::() > len { + read_num_bytes_u64(len, buf.as_slice()) + } else { + read_u64(buf.as_slice()) + }; BitReader { buffer: buf, buffered_values, @@ -574,8 +588,11 @@ impl BitReader { pub fn reset(&mut self, buf: Buffer) { self.buffer = buf; self.total_bytes = self.buffer.len(); - let num_bytes = cmp::min(8, self.total_bytes); - self.buffered_values = read_num_bytes_u64(num_bytes, self.buffer.as_slice()); + self.buffered_values = if size_of::() > self.total_bytes { + read_num_bytes_u64(self.total_bytes, self.buffer.as_slice()) + } else { + read_u64(self.buffer.as_slice()) + }; self.byte_offset = 0; self.bit_offset = 0; } @@ -597,19 +614,7 @@ impl BitReader { return None; } - let mut v = - trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset; - self.bit_offset += num_bits; - - if self.bit_offset >= 64 { - self.byte_offset += 8; - self.bit_offset -= 64; - - self.reload_buffer_values(); - v |= trailing_bits(self.buffered_values, self.bit_offset) - .wrapping_shl((num_bits - self.bit_offset) as u32); - } - + let v = self.get_u64_value(num_bits); Some(T::from(v)) } @@ -625,20 +630,26 @@ impl BitReader { /// Undefined behavior will happen if any of the above assumptions is violated. #[inline] pub fn get_u32_value(&mut self, num_bits: usize) -> u32 { - let mut v = - trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset; - self.bit_offset += num_bits; - - if self.bit_offset >= 64 { - self.byte_offset += 8; - self.bit_offset -= 64; + self.get_u64_value(num_bits) as u32 + } - self.reload_buffer_values(); - v |= trailing_bits(self.buffered_values, self.bit_offset) - .wrapping_shl((num_bits - self.bit_offset) as u32); + #[inline(always)] + fn get_u64_value(&mut self, num_bits: usize) -> u64 { + if unlikely(num_bits == 0) { + 0 + } else { + let v = self.buffered_values >> self.bit_offset; + let mask = u64::MAX >> (64 - num_bits); + self.bit_offset += num_bits; + if self.bit_offset < 64 { + v & mask + } else { + self.byte_offset += 8; + self.bit_offset -= 64; + self.reload_buffer_values(); + ((self.buffered_values << (num_bits - self.bit_offset)) | v) & mask + } } - - v as u32 } /// Gets at most `num` bits from this reader, and append them to the `dst` byte slice, starting @@ -974,9 +985,12 @@ impl BitReader { } fn reload_buffer_values(&mut self) { - let bytes_to_read = cmp::min(self.total_bytes - self.byte_offset, 8); - self.buffered_values = - read_num_bytes_u64(bytes_to_read, &self.buffer.as_slice()[self.byte_offset..]); + let bytes_to_read = self.total_bytes - self.byte_offset; + self.buffered_values = if 8 > bytes_to_read { + read_num_bytes_u64(bytes_to_read, &self.buffer.as_slice()[self.byte_offset..]) + } else { + read_u64(&self.buffer.as_slice()[self.byte_offset..]) + }; } } @@ -1019,6 +1033,12 @@ mod tests { } } + #[test] + fn test_read_u64() { + let buffer: Vec = vec![0, 1, 2, 3, 4, 5, 6, 7]; + assert_eq!(read_u64(&buffer), read_num_bytes!(u64, 8, &buffer),); + } + #[test] fn test_read_num_bytes_u32() { let buffer: Vec = vec![0, 1, 2, 3]; @@ -1030,6 +1050,12 @@ mod tests { } } + #[test] + fn test_read_u32() { + let buffer: Vec = vec![0, 1, 2, 3]; + assert_eq!(read_u32(&buffer), read_num_bytes!(u32, 4, &buffer),); + } + #[test] fn test_ceil() { assert_eq!(ceil(0, 1), 0); diff --git a/native/core/src/parquet/mutable_vector.rs b/native/core/src/parquet/mutable_vector.rs index f1428fd39..7f30d7d87 100644 --- a/native/core/src/parquet/mutable_vector.rs +++ b/native/core/src/parquet/mutable_vector.rs @@ -154,7 +154,7 @@ impl ParquetMutableVector { // We need to update offset buffer for binary. if Self::is_binary_type(&self.arrow_type) { let mut offset = self.num_values * 4; - let prev_offset_value = bit::read_num_bytes_u32(4, &self.value_buffer[offset..]); + let prev_offset_value = bit::read_u32(&self.value_buffer[offset..]); offset += 4; (0..n).for_each(|_| { bit::memcpy_value(&prev_offset_value, 4, &mut self.value_buffer[offset..]); diff --git a/native/core/src/parquet/read/levels.rs b/native/core/src/parquet/read/levels.rs index 303db54c8..3d74b277c 100644 --- a/native/core/src/parquet/read/levels.rs +++ b/native/core/src/parquet/read/levels.rs @@ -22,7 +22,7 @@ use parquet::schema::types::ColumnDescPtr; use super::values::Decoder; use crate::{ - common::bit::{self, read_num_bytes_u32, BitReader}, + common::bit::{self, read_u32, BitReader}, parquet::ParquetMutableVector, unlikely, }; @@ -89,7 +89,7 @@ impl LevelDecoder { 0 } else if self.need_length { let u32_size = mem::size_of::(); - let data_size = read_num_bytes_u32(u32_size, page_data.as_slice()) as usize; + let data_size = read_u32(page_data.as_slice()) as usize; self.bit_reader = Some(BitReader::new(page_data.slice(u32_size), data_size)); u32_size + data_size } else {