Skip to content

Commit

Permalink
Avoid unnecessary branching in row read/write if schema is null-free (#1891)
Browse files Browse the repository at this point in the history

* Avoid unnecessary branching in row read/write if schema is null-free

* test null free code path for binary as well

* name nf to null_free
  • Loading branch information
yjshen authored Mar 1, 2022
1 parent 7eb3bd8 commit cc22e17
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 94 deletions.
8 changes: 1 addition & 7 deletions datafusion/benches/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ extern crate datafusion;
mod data_utils;
use crate::criterion::Criterion;
use crate::data_utils::{create_record_batches, create_schema};
use datafusion::row::writer::{
bench_write_batch, bench_write_batch_jit, bench_write_batch_jit_dummy,
};
use datafusion::row::writer::{bench_write_batch, bench_write_batch_jit};
use std::sync::Arc;

fn criterion_benchmark(c: &mut Criterion) {
Expand All @@ -48,10 +46,6 @@ fn criterion_benchmark(c: &mut Criterion) {
criterion::black_box(bench_write_batch_jit(&batches, schema.clone()).unwrap())
})
});

c.bench_function("row serializer jit codegen only", |b| {
b.iter(|| bench_write_batch_jit_dummy(schema.clone()).unwrap())
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
80 changes: 78 additions & 2 deletions datafusion/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ fn fn_name<T>(f: T) -> &'static str {
}
}

/// Returns `true` when no field of `schema` is declared nullable.
///
/// A schema without any fields is reported as null-free.
pub fn schema_null_free(schema: &Arc<Schema>) -> bool {
    !schema.fields().iter().any(|field| field.is_nullable())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -323,7 +328,7 @@ mod tests {
#[test]
#[allow(non_snake_case)]
fn [<test_single_ $TYPE>]() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, false)]));
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, true)]));
let a = $ARRAY::from($VEC);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 1024];
Expand All @@ -349,6 +354,38 @@ mod tests {
assert_eq!(batch, output_batch);
Ok(())
}

/// Round-trips a single non-nullable column through `write_batch_unchecked`
/// and `read_as_batch`, expecting the decoded batch to equal the input.
#[test]
#[allow(non_snake_case)]
fn [<test_single_ $TYPE _null_free>]() -> Result<()> {
    // The field is declared non-nullable, so `None` entries are dropped
    // from the sample values before the array is built.
    let field = Field::new("a", $TYPE, false);
    let schema = Arc::new(Schema::new(vec![field]));
    let non_null: Vec<_> = $VEC.into_iter().filter(|item| item.is_some()).collect();
    let column = $ARRAY::from(non_null);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)])?;

    let mut buffer = vec![0; 1024];
    let row_offsets = write_batch_unchecked(&mut buffer, 0, &batch, 0, schema.clone());
    let output_batch = read_as_batch(&buffer, schema, row_offsets)?;

    assert_eq!(batch, output_batch);
    Ok(())
}

/// Round-trips a single non-nullable column through the JIT entry points
/// `write_batch_unchecked_jit` and `read_as_batch_jit`, expecting the decoded
/// batch to equal the input.
#[test]
#[allow(non_snake_case)]
#[cfg(feature = "jit")]
fn [<test_single_ $TYPE _jit_null_free>]() -> Result<()> {
    // The field is declared non-nullable, so `None` entries are dropped
    // from the sample values before the array is built.
    let field = Field::new("a", $TYPE, false);
    let schema = Arc::new(Schema::new(vec![field]));
    let non_null: Vec<_> = $VEC.into_iter().filter(|item| item.is_some()).collect();
    let column = $ARRAY::from(non_null);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)])?;

    let mut buffer = vec![0; 1024];
    // The same assembler is shared by the write and the read side.
    let assembler = Assembler::default();
    let row_offsets =
        write_batch_unchecked_jit(&mut buffer, 0, &batch, 0, schema.clone(), &assembler)?;
    let output_batch = read_as_batch_jit(&buffer, schema, row_offsets, &assembler)?;

    assert_eq!(batch, output_batch);
    Ok(())
}
}
};
}
Expand Down Expand Up @@ -439,7 +476,7 @@ mod tests {

#[test]
fn test_single_binary() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, false)]));
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, true)]));
let values: Vec<Option<&[u8]>> =
vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")];
let a = BinaryArray::from_opt_vec(values);
Expand Down Expand Up @@ -478,6 +515,45 @@ mod tests {
Ok(())
}

/// Round-trips a non-nullable `Binary` column (including an empty value)
/// through `write_batch_unchecked` and `read_as_batch`.
#[test]
fn test_single_binary_null_free() -> Result<()> {
    let field = Field::new("a", Binary, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let payloads: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"];
    let column = BinaryArray::from_vec(payloads);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)])?;

    let mut buffer = vec![0; 8192];
    let row_offsets = write_batch_unchecked(&mut buffer, 0, &batch, 0, schema.clone());
    let output_batch = read_as_batch(&buffer, schema, row_offsets)?;

    assert_eq!(batch, output_batch);
    Ok(())
}

/// Round-trips a non-nullable `Binary` column (including an empty value)
/// through the JIT entry points `write_batch_unchecked_jit` and
/// `read_as_batch_jit`.
#[test]
#[cfg(feature = "jit")]
fn test_single_binary_jit_null_free() -> Result<()> {
    let field = Field::new("a", Binary, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let payloads: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"];
    let column = BinaryArray::from_vec(payloads);
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)])?;

    let mut buffer = vec![0; 8192];
    // The same assembler is shared by the write and the read side.
    let assembler = Assembler::default();
    let row_offsets =
        write_batch_unchecked_jit(&mut buffer, 0, &batch, 0, schema.clone(), &assembler)?;
    let output_batch = read_as_batch_jit(&buffer, schema, row_offsets, &assembler)?;

    assert_eq!(batch, output_batch);
    Ok(())
}

#[tokio::test]
async fn test_with_parquet() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
Expand Down
142 changes: 82 additions & 60 deletions datafusion/src/row/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use crate::error::{DataFusionError, Result};
use crate::reg_fn;
#[cfg(feature = "jit")]
use crate::row::fn_name;
use crate::row::{all_valid, get_offsets, supported, NullBitsFormatter};
use crate::row::{
all_valid, get_offsets, schema_null_free, supported, NullBitsFormatter,
};
use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::error::Result as ArrowResult;
Expand Down Expand Up @@ -133,32 +135,40 @@ pub struct RowReader<'a> {
/// For fixed length fields, it's where the actual data stores.
/// For variable length fields, it's a pack of (offset << 32 | length) if we use u64.
field_offsets: Vec<usize>,
/// If a row is null free according to its schema
null_free: bool,
}

impl<'a> std::fmt::Debug for RowReader<'a> {
    /// Formats the row's null information.
    ///
    /// Writes the literal `null_free` when the reader was built for a
    /// null-free schema; otherwise writes the row's null bits through
    /// `NullBitsFormatter`, sized by `field_count`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.null_free {
            write!(f, "null_free")
        } else {
            // Only look at the null bits when the schema can contain nulls.
            let null_bits = self.null_bits();
            write!(
                f,
                "{:?}",
                NullBitsFormatter::new(null_bits, self.field_count)
            )
        }
    }
}

impl<'a> RowReader<'a> {
/// Creates a reader over `data` for rows described by `schema`.
///
/// The reader starts at `base_offset` 0. Whether the schema is null-free is
/// computed once here via `schema_null_free` and cached in `null_free`.
///
/// # Panics
///
/// Panics if `supported(schema)` returns `false`.
pub fn new(schema: &Arc<Schema>, data: &'a [u8]) -> Self {
    assert!(supported(schema));
    let null_free = schema_null_free(schema);
    let field_count = schema.fields().len();
    // A null-free schema reserves no bytes for null bits; otherwise the
    // width is `ceil(field_count, 8)`. Field offsets depend on this width,
    // so it must be settled before calling `get_offsets`.
    let null_width = if null_free { 0 } else { ceil(field_count, 8) };
    let (field_offsets, _) = get_offsets(null_width, schema);
    Self {
        data,
        base_offset: 0,
        field_count,
        null_width,
        field_offsets,
        null_free,
    }
}

Expand All @@ -174,14 +184,22 @@ impl<'a> RowReader<'a> {

#[inline(always)]
fn null_bits(&self) -> &[u8] {
let start = self.base_offset;
&self.data[start..start + self.null_width]
if self.null_free {
&[]
} else {
let start = self.base_offset;
&self.data[start..start + self.null_width]
}
}

/// Returns `true` without touching the row data when the reader is
/// null-free; otherwise delegates to the module-level `all_valid` helper
/// with this row's null bits and `field_count`.
#[inline(always)]
fn all_valid(&self) -> bool {
    // `||` short-circuits, so the null bits are never read on the
    // null-free path.
    self.null_free || all_valid(self.null_bits(), self.field_count)
}

fn is_valid_at(&self, idx: usize) -> bool {
Expand Down Expand Up @@ -276,7 +294,7 @@ impl<'a> RowReader<'a> {
}

fn read_row(row: &RowReader, batch: &mut MutableRecordBatch, schema: &Arc<Schema>) {
if row.all_valid() {
if row.null_free || row.all_valid() {
for ((col_idx, to), field) in batch
.arrays
.iter_mut()
Expand Down Expand Up @@ -325,21 +343,21 @@ fn register_read_functions(asm: &Assembler) -> Result<()> {
reg_fn!(asm, read_field_date64, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8, reader_param.clone(), None);
reg_fn!(asm, read_field_binary, reader_param.clone(), None);
reg_fn!(asm, read_field_bool_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u16_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i16_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_f32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_f64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_date32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_date64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_binary_nf, reader_param, None);
reg_fn!(asm, read_field_bool_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u16_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i16_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_f32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_f64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_date32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_date64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_binary_null_free, reader_param, None);
Ok(())
}

Expand Down Expand Up @@ -383,21 +401,21 @@ fn gen_read_row(
}
} else {
match dt {
Boolean => b.call_stmt("read_field_bool_nf", params)?,
UInt8 => b.call_stmt("read_field_u8_nf", params)?,
UInt16 => b.call_stmt("read_field_u16_nf", params)?,
UInt32 => b.call_stmt("read_field_u32_nf", params)?,
UInt64 => b.call_stmt("read_field_u64_nf", params)?,
Int8 => b.call_stmt("read_field_i8_nf", params)?,
Int16 => b.call_stmt("read_field_i16_nf", params)?,
Int32 => b.call_stmt("read_field_i32_nf", params)?,
Int64 => b.call_stmt("read_field_i64_nf", params)?,
Float32 => b.call_stmt("read_field_f32_nf", params)?,
Float64 => b.call_stmt("read_field_f64_nf", params)?,
Date32 => b.call_stmt("read_field_date32_nf", params)?,
Date64 => b.call_stmt("read_field_date64_nf", params)?,
Utf8 => b.call_stmt("read_field_utf8_nf", params)?,
Binary => b.call_stmt("read_field_binary_nf", params)?,
Boolean => b.call_stmt("read_field_bool_null_free", params)?,
UInt8 => b.call_stmt("read_field_u8_null_free", params)?,
UInt16 => b.call_stmt("read_field_u16_null_free", params)?,
UInt32 => b.call_stmt("read_field_u32_null_free", params)?,
UInt64 => b.call_stmt("read_field_u64_null_free", params)?,
Int8 => b.call_stmt("read_field_i8_null_free", params)?,
Int16 => b.call_stmt("read_field_i16_null_free", params)?,
Int32 => b.call_stmt("read_field_i32_null_free", params)?,
Int64 => b.call_stmt("read_field_i64_null_free", params)?,
Float32 => b.call_stmt("read_field_f32_null_free", params)?,
Float64 => b.call_stmt("read_field_f64_null_free", params)?,
Date32 => b.call_stmt("read_field_date32_null_free", params)?,
Date64 => b.call_stmt("read_field_date64_null_free", params)?,
Utf8 => b.call_stmt("read_field_utf8_null_free", params)?,
Binary => b.call_stmt("read_field_binary_null_free", params)?,
_ => unimplemented!(),
}
}
Expand All @@ -418,7 +436,7 @@ macro_rules! fn_read_field {
.unwrap();
}

fn [<read_field_ $NATIVE _nf>](to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
fn [<read_field_ $NATIVE _null_free>](to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
let to = to
.as_any_mut()
.downcast_mut::<$ARRAY>()
Expand Down Expand Up @@ -455,7 +473,11 @@ fn read_field_binary(to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowRe
}
}

fn read_field_binary_nf(to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
fn read_field_binary_null_free(
to: &mut Box<dyn ArrayBuilder>,
col_idx: usize,
row: &RowReader,
) {
let to = to.as_any_mut().downcast_mut::<BinaryBuilder>().unwrap();
to.append_value(row.get_binary(col_idx))
.map_err(DataFusionError::ArrowError)
Expand Down Expand Up @@ -497,21 +519,21 @@ fn read_field_null_free(
) {
use DataType::*;
match dt {
Boolean => read_field_bool_nf(to, col_idx, row),
UInt8 => read_field_u8_nf(to, col_idx, row),
UInt16 => read_field_u16_nf(to, col_idx, row),
UInt32 => read_field_u32_nf(to, col_idx, row),
UInt64 => read_field_u64_nf(to, col_idx, row),
Int8 => read_field_i8_nf(to, col_idx, row),
Int16 => read_field_i16_nf(to, col_idx, row),
Int32 => read_field_i32_nf(to, col_idx, row),
Int64 => read_field_i64_nf(to, col_idx, row),
Float32 => read_field_f32_nf(to, col_idx, row),
Float64 => read_field_f64_nf(to, col_idx, row),
Date32 => read_field_date32_nf(to, col_idx, row),
Date64 => read_field_date64_nf(to, col_idx, row),
Utf8 => read_field_utf8_nf(to, col_idx, row),
Binary => read_field_binary_nf(to, col_idx, row),
Boolean => read_field_bool_null_free(to, col_idx, row),
UInt8 => read_field_u8_null_free(to, col_idx, row),
UInt16 => read_field_u16_null_free(to, col_idx, row),
UInt32 => read_field_u32_null_free(to, col_idx, row),
UInt64 => read_field_u64_null_free(to, col_idx, row),
Int8 => read_field_i8_null_free(to, col_idx, row),
Int16 => read_field_i16_null_free(to, col_idx, row),
Int32 => read_field_i32_null_free(to, col_idx, row),
Int64 => read_field_i64_null_free(to, col_idx, row),
Float32 => read_field_f32_null_free(to, col_idx, row),
Float64 => read_field_f64_null_free(to, col_idx, row),
Date32 => read_field_date32_null_free(to, col_idx, row),
Date64 => read_field_date64_null_free(to, col_idx, row),
Utf8 => read_field_utf8_null_free(to, col_idx, row),
Binary => read_field_binary_null_free(to, col_idx, row),
_ => unimplemented!(),
}
}
Expand Down
Loading

0 comments on commit cc22e17

Please sign in to comment.