Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unnecessary branching in row read/write if schema is null-free #1891

Merged
merged 3 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions datafusion/benches/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ extern crate datafusion;
mod data_utils;
use crate::criterion::Criterion;
use crate::data_utils::{create_record_batches, create_schema};
use datafusion::row::writer::{
bench_write_batch, bench_write_batch_jit, bench_write_batch_jit_dummy,
};
use datafusion::row::writer::{bench_write_batch, bench_write_batch_jit};
use std::sync::Arc;

fn criterion_benchmark(c: &mut Criterion) {
Expand All @@ -48,10 +46,6 @@ fn criterion_benchmark(c: &mut Criterion) {
criterion::black_box(bench_write_batch_jit(&batches, schema.clone()).unwrap())
})
});

c.bench_function("row serializer jit codegen only", |b| {
b.iter(|| bench_write_batch_jit_dummy(schema.clone()).unwrap())
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
80 changes: 78 additions & 2 deletions datafusion/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ fn fn_name<T>(f: T) -> &'static str {
}
}

/// Tell if schema contains no nullable field
pub fn schema_null_free(schema: &Arc<Schema>) -> bool {
schema.fields().iter().all(|f| !f.is_nullable())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -323,7 +328,7 @@ mod tests {
#[test]
#[allow(non_snake_case)]
fn [<test_single_ $TYPE>]() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, false)]));
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, true)]));
let a = $ARRAY::from($VEC);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 1024];
Expand All @@ -349,6 +354,38 @@ mod tests {
assert_eq!(batch, output_batch);
Ok(())
}

#[test]
#[allow(non_snake_case)]
fn [<test_single_ $TYPE _null_free>]() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, false)]));
let v = $VEC.into_iter().filter(|o| o.is_some()).collect::<Vec<_>>();
let a = $ARRAY::from(v);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 1024];
let row_offsets =
{ write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) };
let output_batch = { read_as_batch(&vector, schema, row_offsets)? };
assert_eq!(batch, output_batch);
Ok(())
}

#[test]
#[allow(non_snake_case)]
#[cfg(feature = "jit")]
fn [<test_single_ $TYPE _jit_null_free>]() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", $TYPE, false)]));
let v = $VEC.into_iter().filter(|o| o.is_some()).collect::<Vec<_>>();
let a = $ARRAY::from(v);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 1024];
let assembler = Assembler::default();
let row_offsets =
{ write_batch_unchecked_jit(&mut vector, 0, &batch, 0, schema.clone(), &assembler)? };
let output_batch = { read_as_batch_jit(&vector, schema, row_offsets, &assembler)? };
assert_eq!(batch, output_batch);
Ok(())
}
}
};
}
Expand Down Expand Up @@ -439,7 +476,7 @@ mod tests {

#[test]
fn test_single_binary() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, false)]));
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, true)]));
yjshen marked this conversation as resolved.
Show resolved Hide resolved
let values: Vec<Option<&[u8]>> =
vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")];
let a = BinaryArray::from_opt_vec(values);
Expand Down Expand Up @@ -478,6 +515,45 @@ mod tests {
Ok(())
}

#[test]
fn test_single_binary_null_free() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, false)]));
let values: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"];
let a = BinaryArray::from_vec(values);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 8192];
let row_offsets =
{ write_batch_unchecked(&mut vector, 0, &batch, 0, schema.clone()) };
let output_batch = { read_as_batch(&vector, schema, row_offsets)? };
assert_eq!(batch, output_batch);
Ok(())
}

#[test]
#[cfg(feature = "jit")]
fn test_single_binary_jit_null_free() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("a", Binary, false)]));
let values: Vec<&[u8]> = vec![b"one", b"two", b"", b"three"];
let a = BinaryArray::from_vec(values);
let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a)])?;
let mut vector = vec![0; 8192];
let assembler = Assembler::default();
let row_offsets = {
write_batch_unchecked_jit(
&mut vector,
0,
&batch,
0,
schema.clone(),
&assembler,
)?
};
let output_batch =
{ read_as_batch_jit(&vector, schema, row_offsets, &assembler)? };
assert_eq!(batch, output_batch);
Ok(())
}

#[tokio::test]
async fn test_with_parquet() -> Result<()> {
let runtime = Arc::new(RuntimeEnv::default());
Expand Down
142 changes: 82 additions & 60 deletions datafusion/src/row/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use crate::error::{DataFusionError, Result};
use crate::reg_fn;
#[cfg(feature = "jit")]
use crate::row::fn_name;
use crate::row::{all_valid, get_offsets, supported, NullBitsFormatter};
use crate::row::{
all_valid, get_offsets, schema_null_free, supported, NullBitsFormatter,
};
use arrow::array::*;
use arrow::datatypes::{DataType, Schema};
use arrow::error::Result as ArrowResult;
Expand Down Expand Up @@ -133,32 +135,40 @@ pub struct RowReader<'a> {
/// For fixed length fields, it's where the actual data stores.
/// For variable length fields, it's a pack of (offset << 32 | length) if we use u64.
field_offsets: Vec<usize>,
/// If a row is null free according to its schema
null_free: bool,
}

impl<'a> std::fmt::Debug for RowReader<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let null_bits = self.null_bits();
write!(
f,
"{:?}",
NullBitsFormatter::new(null_bits, self.field_count)
)
if self.null_free {
write!(f, "null_free")
} else {
let null_bits = self.null_bits();
write!(
f,
"{:?}",
NullBitsFormatter::new(null_bits, self.field_count)
)
}
}
}

impl<'a> RowReader<'a> {
/// new
pub fn new(schema: &Arc<Schema>, data: &'a [u8]) -> Self {
assert!(supported(schema));
let null_free = schema_null_free(schema);
let field_count = schema.fields().len();
let null_width = ceil(field_count, 8);
let null_width = if null_free { 0 } else { ceil(field_count, 8) };
let (field_offsets, _) = get_offsets(null_width, schema);
Self {
data,
base_offset: 0,
field_count,
null_width,
field_offsets,
null_free,
}
}

Expand All @@ -174,14 +184,22 @@ impl<'a> RowReader<'a> {

#[inline(always)]
fn null_bits(&self) -> &[u8] {
let start = self.base_offset;
&self.data[start..start + self.null_width]
if self.null_free {
&[]
} else {
let start = self.base_offset;
&self.data[start..start + self.null_width]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if null_width is always zero, I wonder if the check for self.null_free is needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for not null_free code path. Actually this method shouldn't be touched when tuples are null-free

}
}

#[inline(always)]
fn all_valid(&self) -> bool {
let null_bits = self.null_bits();
all_valid(null_bits, self.field_count)
if self.null_free {
true
} else {
let null_bits = self.null_bits();
all_valid(null_bits, self.field_count)
}
}

fn is_valid_at(&self, idx: usize) -> bool {
Expand Down Expand Up @@ -276,7 +294,7 @@ impl<'a> RowReader<'a> {
}

fn read_row(row: &RowReader, batch: &mut MutableRecordBatch, schema: &Arc<Schema>) {
if row.all_valid() {
if row.null_free || row.all_valid() {
for ((col_idx, to), field) in batch
.arrays
.iter_mut()
Expand Down Expand Up @@ -325,21 +343,21 @@ fn register_read_functions(asm: &Assembler) -> Result<()> {
reg_fn!(asm, read_field_date64, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8, reader_param.clone(), None);
reg_fn!(asm, read_field_binary, reader_param.clone(), None);
reg_fn!(asm, read_field_bool_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u16_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_u64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i16_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_i64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_f32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_f64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_date32_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_date64_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8_nf, reader_param.clone(), None);
reg_fn!(asm, read_field_binary_nf, reader_param, None);
reg_fn!(asm, read_field_bool_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u16_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_u64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i16_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_i64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_f32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_f64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_date32_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_date64_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_utf8_null_free, reader_param.clone(), None);
reg_fn!(asm, read_field_binary_null_free, reader_param, None);
Ok(())
}

Expand Down Expand Up @@ -383,21 +401,21 @@ fn gen_read_row(
}
} else {
match dt {
Boolean => b.call_stmt("read_field_bool_nf", params)?,
UInt8 => b.call_stmt("read_field_u8_nf", params)?,
UInt16 => b.call_stmt("read_field_u16_nf", params)?,
UInt32 => b.call_stmt("read_field_u32_nf", params)?,
UInt64 => b.call_stmt("read_field_u64_nf", params)?,
Int8 => b.call_stmt("read_field_i8_nf", params)?,
Int16 => b.call_stmt("read_field_i16_nf", params)?,
Int32 => b.call_stmt("read_field_i32_nf", params)?,
Int64 => b.call_stmt("read_field_i64_nf", params)?,
Float32 => b.call_stmt("read_field_f32_nf", params)?,
Float64 => b.call_stmt("read_field_f64_nf", params)?,
Date32 => b.call_stmt("read_field_date32_nf", params)?,
Date64 => b.call_stmt("read_field_date64_nf", params)?,
Utf8 => b.call_stmt("read_field_utf8_nf", params)?,
Binary => b.call_stmt("read_field_binary_nf", params)?,
Boolean => b.call_stmt("read_field_bool_null_free", params)?,
UInt8 => b.call_stmt("read_field_u8_null_free", params)?,
UInt16 => b.call_stmt("read_field_u16_null_free", params)?,
UInt32 => b.call_stmt("read_field_u32_null_free", params)?,
UInt64 => b.call_stmt("read_field_u64_null_free", params)?,
Int8 => b.call_stmt("read_field_i8_null_free", params)?,
Int16 => b.call_stmt("read_field_i16_null_free", params)?,
Int32 => b.call_stmt("read_field_i32_null_free", params)?,
Int64 => b.call_stmt("read_field_i64_null_free", params)?,
Float32 => b.call_stmt("read_field_f32_null_free", params)?,
Float64 => b.call_stmt("read_field_f64_null_free", params)?,
Date32 => b.call_stmt("read_field_date32_null_free", params)?,
Date64 => b.call_stmt("read_field_date64_null_free", params)?,
Utf8 => b.call_stmt("read_field_utf8_null_free", params)?,
Binary => b.call_stmt("read_field_binary_null_free", params)?,
_ => unimplemented!(),
}
}
Expand All @@ -418,7 +436,7 @@ macro_rules! fn_read_field {
.unwrap();
}

fn [<read_field_ $NATIVE _nf>](to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
fn [<read_field_ $NATIVE _null_free>](to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
let to = to
.as_any_mut()
.downcast_mut::<$ARRAY>()
Expand Down Expand Up @@ -455,7 +473,11 @@ fn read_field_binary(to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowRe
}
}

fn read_field_binary_nf(to: &mut Box<dyn ArrayBuilder>, col_idx: usize, row: &RowReader) {
fn read_field_binary_null_free(
to: &mut Box<dyn ArrayBuilder>,
col_idx: usize,
row: &RowReader,
) {
let to = to.as_any_mut().downcast_mut::<BinaryBuilder>().unwrap();
to.append_value(row.get_binary(col_idx))
.map_err(DataFusionError::ArrowError)
Expand Down Expand Up @@ -497,21 +519,21 @@ fn read_field_null_free(
) {
use DataType::*;
match dt {
Boolean => read_field_bool_nf(to, col_idx, row),
UInt8 => read_field_u8_nf(to, col_idx, row),
UInt16 => read_field_u16_nf(to, col_idx, row),
UInt32 => read_field_u32_nf(to, col_idx, row),
UInt64 => read_field_u64_nf(to, col_idx, row),
Int8 => read_field_i8_nf(to, col_idx, row),
Int16 => read_field_i16_nf(to, col_idx, row),
Int32 => read_field_i32_nf(to, col_idx, row),
Int64 => read_field_i64_nf(to, col_idx, row),
Float32 => read_field_f32_nf(to, col_idx, row),
Float64 => read_field_f64_nf(to, col_idx, row),
Date32 => read_field_date32_nf(to, col_idx, row),
Date64 => read_field_date64_nf(to, col_idx, row),
Utf8 => read_field_utf8_nf(to, col_idx, row),
Binary => read_field_binary_nf(to, col_idx, row),
Boolean => read_field_bool_null_free(to, col_idx, row),
UInt8 => read_field_u8_null_free(to, col_idx, row),
UInt16 => read_field_u16_null_free(to, col_idx, row),
UInt32 => read_field_u32_null_free(to, col_idx, row),
UInt64 => read_field_u64_null_free(to, col_idx, row),
Int8 => read_field_i8_null_free(to, col_idx, row),
Int16 => read_field_i16_null_free(to, col_idx, row),
Int32 => read_field_i32_null_free(to, col_idx, row),
Int64 => read_field_i64_null_free(to, col_idx, row),
Float32 => read_field_f32_null_free(to, col_idx, row),
Float64 => read_field_f64_null_free(to, col_idx, row),
Date32 => read_field_date32_null_free(to, col_idx, row),
Date64 => read_field_date64_null_free(to, col_idx, row),
Utf8 => read_field_utf8_null_free(to, col_idx, row),
Binary => read_field_binary_null_free(to, col_idx, row),
_ => unimplemented!(),
}
}
Expand Down
Loading