fix: optimize some bit functions (#718)
## Which issue does this PR close?

Part of #679 and #670

## Rationale for this change

The improvement could be negligible in real use cases, but I see some gains in micro-benchmarks.

## What changes are included in this PR?

Optimizations in some bit functions: `trailing_bits`, the `BitReader` buffering and getter paths, and new `read_u32`/`read_u64` helpers, with micro-benchmarks and unit tests for the new helpers.
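
In short, the pattern is to replace the double shift in `trailing_bits` with a single mask, and to replace per-byte reads with an unaligned word read whenever at least `size_of::<u64>()` bytes remain. A minimal standalone sketch of that pattern (simplified from the diff below; it omits the `unlikely` branch hint and the short-buffer fallback used in the real code):

```rust
/// Keep only the lowest `num_bits` bits of `v`.
fn trailing_bits(v: u64, num_bits: usize) -> u64 {
    if num_bits >= 64 {
        return v; // a 64-bit shift would overflow, so return early
    }
    // Single AND with a mask instead of `(v << (64 - num_bits)) >> (64 - num_bits)`.
    v & ((1 << num_bits) - 1)
}

/// Read 8 bytes as a native-endian u64 without a per-byte loop.
/// The caller must guarantee `src.len() >= 8`.
fn read_u64(src: &[u8]) -> u64 {
    let in_ptr = src.as_ptr() as *const u64;
    // Safety: length is guaranteed by the caller; `read_unaligned` accepts any alignment.
    unsafe { in_ptr.read_unaligned() }
}

fn main() {
    assert_eq!(trailing_bits(0b1111_0110, 4), 0b0110);
    // On little-endian targets the first byte is the least significant one.
    assert_eq!(read_u64(&[1, 0, 0, 0, 0, 0, 0, 0]), 1);
}
```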

## How are these changes tested?

Existing tests
kazuyukitanimura authored Aug 2, 2024
1 parent 698c1b2 commit ffb96c3
Showing 4 changed files with 98 additions and 39 deletions.
35 changes: 34 additions & 1 deletion native/core/benches/bit_util.rs
@@ -21,7 +21,8 @@ use rand::{thread_rng, Rng};
 
 use arrow::buffer::Buffer;
 use comet::common::bit::{
-    log2, read_num_bytes_u32, read_num_bytes_u64, set_bits, BitReader, BitWriter,
+    log2, read_num_bytes_u32, read_num_bytes_u64, read_u32, read_u64, set_bits, trailing_bits,
+    BitReader, BitWriter,
 };
 use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
 
@@ -159,6 +160,38 @@ fn criterion_benchmark(c: &mut Criterion) {
         );
     }
 
+    // trailing_bits
+    for length in (0..=64).step_by(32) {
+        let x = length;
+        group.bench_with_input(
+            BenchmarkId::new("trailing_bits", format!("num_bits_{}", x)),
+            &x,
+            |b, &x| {
+                b.iter(|| trailing_bits(black_box(1234567890), black_box(x)));
+            },
+        );
+    }
+
+    // read_u64
+    group.bench_function("read_u64", |b| {
+        b.iter(|| read_u64(black_box(&[0u8; 8])));
+    });
+
+    // read_u32
+    group.bench_function("read_u32", |b| {
+        b.iter(|| read_u32(black_box(&[0u8; 4])));
+    });
+
+    // get_u32_value
+    group.bench_function("get_u32_value", |b| {
+        b.iter(|| {
+            let mut reader: BitReader = BitReader::new_all(buffer.slice(0));
+            for _ in 0..(buffer.len() * 8 / 31) {
+                black_box(reader.get_u32_value(black_box(31)));
+            }
+        })
+    });
+
     group.finish();
 }

96 changes: 61 additions & 35 deletions native/core/src/common/bit.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{cmp, cmp::min, mem::size_of};
+use std::{cmp::min, mem::size_of};
 
 use arrow::buffer::Buffer;
 
@@ -131,6 +131,18 @@ pub fn read_num_bytes_u32(size: usize, src: &[u8]) -> u32 {
     trailing_bits(v as u64, size * 8) as u32
 }
 
+#[inline]
+pub fn read_u64(src: &[u8]) -> u64 {
+    let in_ptr = src.as_ptr() as *const u64;
+    unsafe { in_ptr.read_unaligned() }
+}
+
+#[inline]
+pub fn read_u32(src: &[u8]) -> u32 {
+    let in_ptr = src.as_ptr() as *const u32;
+    unsafe { in_ptr.read_unaligned() }
+}
+
 /// Similar to the `read_num_bytes` but read nums from bytes in big-endian order
 /// This is used to read bytes from Java's OutputStream which writes bytes in big-endian
 macro_rules! read_num_be_bytes {
@@ -189,8 +201,7 @@ pub fn trailing_bits(v: u64, num_bits: usize) -> u64 {
     if unlikely(num_bits >= 64) {
         return v;
     }
-    let n = 64 - num_bits;
-    (v << n) >> n
+    v & ((1 << num_bits) - 1)
 }
 
 #[inline]
@@ -555,8 +566,11 @@ pub struct BitReader {
 /// either byte aligned or not.
 impl BitReader {
     pub fn new(buf: Buffer, len: usize) -> Self {
-        let num_bytes = cmp::min(8, len);
-        let buffered_values = read_num_bytes_u64(num_bytes, buf.as_slice());
+        let buffered_values = if size_of::<u64>() > len {
+            read_num_bytes_u64(len, buf.as_slice())
+        } else {
+            read_u64(buf.as_slice())
+        };
         BitReader {
             buffer: buf,
             buffered_values,
@@ -574,8 +588,11 @@ impl BitReader {
     pub fn reset(&mut self, buf: Buffer) {
         self.buffer = buf;
         self.total_bytes = self.buffer.len();
-        let num_bytes = cmp::min(8, self.total_bytes);
-        self.buffered_values = read_num_bytes_u64(num_bytes, self.buffer.as_slice());
+        self.buffered_values = if size_of::<u64>() > self.total_bytes {
+            read_num_bytes_u64(self.total_bytes, self.buffer.as_slice())
+        } else {
+            read_u64(self.buffer.as_slice())
+        };
         self.byte_offset = 0;
         self.bit_offset = 0;
     }
@@ -597,19 +614,7 @@ impl BitReader {
             return None;
         }
 
-        let mut v =
-            trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset;
-        self.bit_offset += num_bits;
-
-        if self.bit_offset >= 64 {
-            self.byte_offset += 8;
-            self.bit_offset -= 64;
-
-            self.reload_buffer_values();
-            v |= trailing_bits(self.buffered_values, self.bit_offset)
-                .wrapping_shl((num_bits - self.bit_offset) as u32);
-        }
-
+        let v = self.get_u64_value(num_bits);
         Some(T::from(v))
     }

@@ -625,20 +630,26 @@
     /// Undefined behavior will happen if any of the above assumptions is violated.
     #[inline]
     pub fn get_u32_value(&mut self, num_bits: usize) -> u32 {
-        let mut v =
-            trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset;
-        self.bit_offset += num_bits;
-
-        if self.bit_offset >= 64 {
-            self.byte_offset += 8;
-            self.bit_offset -= 64;
-
-            self.reload_buffer_values();
-            v |= trailing_bits(self.buffered_values, self.bit_offset)
-                .wrapping_shl((num_bits - self.bit_offset) as u32);
-        }
-
-        v as u32
+        self.get_u64_value(num_bits) as u32
+    }
+
+    #[inline(always)]
+    fn get_u64_value(&mut self, num_bits: usize) -> u64 {
+        if unlikely(num_bits == 0) {
+            0
+        } else {
+            let v = self.buffered_values >> self.bit_offset;
+            let mask = u64::MAX >> (64 - num_bits);
+            self.bit_offset += num_bits;
+            if self.bit_offset < 64 {
+                v & mask
+            } else {
+                self.byte_offset += 8;
+                self.bit_offset -= 64;
+                self.reload_buffer_values();
+                ((self.buffered_values << (num_bits - self.bit_offset)) | v) & mask
+            }
+        }
     }
 
     /// Gets at most `num` bits from this reader, and append them to the `dst` byte slice, starting
@@ -974,9 +985,12 @@
     }
 
     fn reload_buffer_values(&mut self) {
-        let bytes_to_read = cmp::min(self.total_bytes - self.byte_offset, 8);
-        self.buffered_values =
-            read_num_bytes_u64(bytes_to_read, &self.buffer.as_slice()[self.byte_offset..]);
+        let bytes_to_read = self.total_bytes - self.byte_offset;
+        self.buffered_values = if 8 > bytes_to_read {
+            read_num_bytes_u64(bytes_to_read, &self.buffer.as_slice()[self.byte_offset..])
+        } else {
+            read_u64(&self.buffer.as_slice()[self.byte_offset..])
+        };
     }
 }

@@ -1019,6 +1033,12 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_read_u64() {
+        let buffer: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
+        assert_eq!(read_u64(&buffer), read_num_bytes!(u64, 8, &buffer),);
+    }
+
     #[test]
     fn test_read_num_bytes_u32() {
         let buffer: Vec<u8> = vec![0, 1, 2, 3];
@@ -1030,6 +1050,12 @@
         }
     }
 
+    #[test]
+    fn test_read_u32() {
+        let buffer: Vec<u8> = vec![0, 1, 2, 3];
+        assert_eq!(read_u32(&buffer), read_num_bytes!(u32, 4, &buffer),);
+    }
+
     #[test]
     fn test_ceil() {
         assert_eq!(ceil(0, 1), 0);
2 changes: 1 addition & 1 deletion native/core/src/parquet/mutable_vector.rs
@@ -154,7 +154,7 @@ impl ParquetMutableVector {
         // We need to update offset buffer for binary.
         if Self::is_binary_type(&self.arrow_type) {
             let mut offset = self.num_values * 4;
-            let prev_offset_value = bit::read_num_bytes_u32(4, &self.value_buffer[offset..]);
+            let prev_offset_value = bit::read_u32(&self.value_buffer[offset..]);
             offset += 4;
             (0..n).for_each(|_| {
                 bit::memcpy_value(&prev_offset_value, 4, &mut self.value_buffer[offset..]);
4 changes: 2 additions & 2 deletions native/core/src/parquet/read/levels.rs
@@ -22,7 +22,7 @@ use parquet::schema::types::ColumnDescPtr;
 
 use super::values::Decoder;
 use crate::{
-    common::bit::{self, read_num_bytes_u32, BitReader},
+    common::bit::{self, read_u32, BitReader},
     parquet::ParquetMutableVector,
     unlikely,
 };
@@ -89,7 +89,7 @@ impl LevelDecoder {
             0
         } else if self.need_length {
             let u32_size = mem::size_of::<u32>();
-            let data_size = read_num_bytes_u32(u32_size, page_data.as_slice()) as usize;
+            let data_size = read_u32(page_data.as_slice()) as usize;
             self.bit_reader = Some(BitReader::new(page_data.slice(u32_size), data_size));
             u32_size + data_size
         } else {
