Skip to content

Commit

Permalink
Add specializations for null / non null
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Sep 29, 2024
1 parent 5ef1038 commit 36a2003
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 91 deletions.
79 changes: 34 additions & 45 deletions datafusion/physical-plan/src/aggregates/group_values/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// under the License.

use crate::aggregates::group_values::group_column::{
ByteGroupValueBuilder, GroupColumn, PrimitiveGroupValueBuilder,
ByteGroupValueBuilder, GroupColumn, NonNullPrimitiveGroupValueBuilder,
PrimitiveGroupValueBuilder,
};
use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
Expand Down Expand Up @@ -116,6 +117,26 @@ impl GroupValuesColumn {
}
}

/// Creates a boxed primitive group-value builder for type `$t` and appends
/// it to `$v`.
///
/// A [`PrimitiveGroupValueBuilder`] is created when `$nullable` is true,
/// otherwise a [`NonNullPrimitiveGroupValueBuilder`] is created.
///
/// Arguments:
/// `$v`: the vector that receives the new builder
/// `$nullable`: whether the input can contain nulls
/// `$t`: the primitive type of the builder
///
macro_rules! instantiate_primitive {
    ($v:expr, $nullable:expr, $t:ty) => {
        if $nullable {
            // `as _` lets the element type of `$v` pick the boxed target type
            $v.push(Box::new(PrimitiveGroupValueBuilder::<$t>::new()) as _)
        } else {
            $v.push(Box::new(NonNullPrimitiveGroupValueBuilder::<$t>::new()) as _)
        }
    };
}

impl GroupValues for GroupValuesColumn {
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
let n_rows = cols[0].len();
Expand All @@ -126,54 +147,22 @@ impl GroupValues for GroupValuesColumn {
for f in self.schema.fields().iter() {
let nullable = f.is_nullable();
match f.data_type() {
&DataType::Int8 => {
let b = PrimitiveGroupValueBuilder::<Int8Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int16 => {
let b = PrimitiveGroupValueBuilder::<Int16Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int32 => {
let b = PrimitiveGroupValueBuilder::<Int32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int64 => {
let b = PrimitiveGroupValueBuilder::<Int64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt8 => {
let b = PrimitiveGroupValueBuilder::<UInt8Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt16 => {
let b = PrimitiveGroupValueBuilder::<UInt16Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt32 => {
let b = PrimitiveGroupValueBuilder::<UInt32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::UInt64 => {
let b = PrimitiveGroupValueBuilder::<UInt64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type),
&DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type),
&DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type),
&DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type),
&DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type),
&DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type),
&DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type),
&DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type),
&DataType::Float32 => {
let b = PrimitiveGroupValueBuilder::<Float32Type>::new(nullable);
v.push(Box::new(b) as _)
instantiate_primitive!(v, nullable, Float32Type)
}
&DataType::Float64 => {
let b = PrimitiveGroupValueBuilder::<Float64Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Date32 => {
let b = PrimitiveGroupValueBuilder::<Date32Type>::new(nullable);
v.push(Box::new(b) as _)
}
&DataType::Date64 => {
let b = PrimitiveGroupValueBuilder::<Date64Type>::new(nullable);
v.push(Box::new(b) as _)
instantiate_primitive!(v, nullable, Float64Type)
}
&DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type),
&DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type),
&DataType::Utf8 => {
let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
v.push(Box::new(b) as _)
Expand Down
120 changes: 74 additions & 46 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,62 +62,96 @@ pub trait GroupColumn: Send + Sync {
fn take_n(&mut self, n: usize) -> ArrayRef;
}

/// Builder that accumulates primitive group values for input that is known
/// to contain no nulls.
///
/// Only the native values are stored; no null state is kept alongside them.
#[derive(Debug)]
pub struct NonNullPrimitiveGroupValueBuilder<T>
where
    T: ArrowPrimitiveType,
{
    /// The group values collected so far, one entry per group
    group_values: Vec<T::Native>,
}

impl<T> NonNullPrimitiveGroupValueBuilder<T>
where
    T: ArrowPrimitiveType,
{
    /// Create a new, empty [`NonNullPrimitiveGroupValueBuilder`]
    ///
    /// No memory is allocated until the first value is appended.
    pub fn new() -> Self {
        Self {
            group_values: Vec::new(),
        }
    }
}

/// A zero-argument public `new()` should be mirrored by `Default` so the
/// builder can be created through `Default::default()` as well.
impl<T> Default for NonNullPrimitiveGroupValueBuilder<T>
where
    T: ArrowPrimitiveType,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T: ArrowPrimitiveType> GroupColumn for NonNullPrimitiveGroupValueBuilder<T> {
    /// Returns true when the stored group value at `lhs_row` equals the value
    /// at `rhs_row` of `array`.
    ///
    /// Validity is never consulted: the input is known to have no nulls.
    fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
        let stored = &self.group_values[lhs_row];
        let incoming = array.as_primitive::<T>().value(rhs_row);
        *stored == incoming
    }

    /// Appends the value at `row` of `array` as a new group value.
    fn append_val(&mut self, array: &ArrayRef, row: usize) {
        // the input cannot contain nulls, so the value is taken unconditionally
        let value = array.as_primitive::<T>().value(row);
        self.group_values.push(value)
    }

    /// Number of group values currently stored.
    fn len(&self) -> usize {
        self.group_values.len()
    }

    /// Allocated size of the stored group values, as reported by
    /// `allocated_size` on the underlying vector.
    fn size(&self) -> usize {
        self.group_values.allocated_size()
    }

    /// Consumes the builder and returns all stored values as a primitive
    /// array without a null buffer.
    fn build(self: Box<Self>) -> ArrayRef {
        let Self {
            group_values: values,
        } = *self;

        // no null buffer: every stored value is valid
        let array = PrimitiveArray::<T>::new(ScalarBuffer::from(values), None);
        Arc::new(array)
    }

    /// Removes the first `n` stored values and returns them as a primitive
    /// array without a null buffer; the remaining values stay in the builder.
    ///
    /// Panics if `n` exceeds the number of stored values (the `Vec::drain`
    /// range check).
    fn take_n(&mut self, n: usize) -> ArrayRef {
        let head: Vec<T::Native> = self.group_values.drain(..n).collect();

        Arc::new(PrimitiveArray::<T>::new(ScalarBuffer::from(head), None))
    }
}

/// Stores a collection of primitive group values which may have nulls
#[derive(Debug)]
pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType> {
group_values: Vec<T::Native>,
/// Null state (when None, input is guaranteed not to have nulls)
nulls: Option<MaybeNullBufferBuilder>,
nulls: MaybeNullBufferBuilder,
}

impl<T> PrimitiveGroupValueBuilder<T>
where
T: ArrowPrimitiveType,
{
/// Create a new [`PrimitiveGroupValueBuilder`]
///
/// If `nullable` is false, it means the input will never have nulls
pub fn new(nullable: bool) -> Self {
let nulls = if nullable {
Some(MaybeNullBufferBuilder::new())
} else {
None
};

pub fn new() -> Self {
Self {
group_values: vec![],
nulls,
nulls: MaybeNullBufferBuilder::new(),
}
}
}

impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
// fast path when input has no nulls
match self.nulls.as_ref() {
None => {
self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}
Some(nulls) => {
// slower path if the input could have nulls
nulls.is_null(lhs_row) == array.is_null(rhs_row)
&& self.group_values[lhs_row]
== array.as_primitive::<T>().value(rhs_row)
}
}
self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
&& self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
match self.nulls.as_mut() {
// input can't possibly have nulls, so don't worry about them
None => self.group_values.push(array.as_primitive::<T>().value(row)),
Some(nulls) => {
if array.is_null(row) {
nulls.append(true);
self.group_values.push(T::default_value());
} else {
nulls.append(false);
self.group_values.push(array.as_primitive::<T>().value(row));
}
}
if array.is_null(row) {
self.nulls.append(true);
self.group_values.push(T::default_value());
} else {
self.nulls.append(false);
self.group_values.push(array.as_primitive::<T>().value(row));
}
}

Expand All @@ -126,13 +160,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
}

fn size(&self) -> usize {
let nulls_size = self
.nulls
.as_ref()
.map(|nulls| nulls.allocated_size())
.unwrap_or(0);

self.group_values.allocated_size() + nulls_size
self.group_values.allocated_size() + self.nulls.allocated_size()
}

fn build(self: Box<Self>) -> ArrayRef {
Expand All @@ -141,7 +169,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {
nulls,
} = *self;

let nulls = nulls.and_then(|nulls| nulls.build());
let nulls = nulls.build();

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(group_values),
Expand All @@ -151,7 +179,7 @@ impl<T: ArrowPrimitiveType> GroupColumn for PrimitiveGroupValueBuilder<T> {

fn take_n(&mut self, n: usize) -> ArrayRef {
let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
let first_n_nulls = self.nulls.as_mut().and_then(|nulls| nulls.take_n(n));
let first_n_nulls = self.nulls.take_n(n);

Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(first_n),
Expand Down

0 comments on commit 36a2003

Please sign in to comment.