Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move unary kernels to arrow-array (#2787) #2789

Merged
merged 2 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 104 additions & 15 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use crate::builder::{BooleanBufferBuilder, PrimitiveBuilder};
use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder};
use crate::iterator::PrimitiveIter;
use crate::raw_pointer::RawPtrBox;
use crate::temporal_conversions::{as_date, as_datetime, as_duration, as_time};
use crate::trusted_len::trusted_len_unzip;
use crate::types::*;
use crate::{print_long_array, Array, ArrayAccessor};
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::ArrayData;
use arrow_schema::DataType;
use chrono::{Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
Expand Down Expand Up @@ -298,20 +299,10 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {

/// Creates a PrimitiveArray based on a constant value with `count` elements
pub fn from_value(value: T::Native, count: usize) -> Self {
// # Safety: iterator (0..count) correctly reports its length
let val_buf = unsafe { Buffer::from_trusted_len_iter((0..count).map(|_| value)) };
let data = unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
val_buf.len() / std::mem::size_of::<<T as ArrowPrimitiveType>::Native>(),
None,
None,
0,
vec![val_buf],
vec![],
)
};
PrimitiveArray::from(data)
unsafe {
let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
build_primitive_array(count, val_buf, 0, None)
}
}

/// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i`
Expand All @@ -332,6 +323,104 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
) -> impl Iterator<Item = Option<T::Native>> + 'a {
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
}

/// Applies an unary and infallible function to a primitive array.
/// This is the fastest way to perform an operation on a primitive array when
/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
///
/// # Implementation
///
/// This will apply the function for all values, including those on null slots.
/// This implies that the operation must be infallible for any value of the corresponding type
/// or this function may panic.
/// # Example
/// ```rust
/// # use arrow_array::{Int32Array, types::Int32Type};
/// # fn main() {
/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
/// let c = array.unary(|x| x * 2 + 1);
/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
/// # }
/// ```
pub fn unary<F, O>(&self, op: F) -> PrimitiveArray<O>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> O::Native,
{
let data = self.data();
let len = self.len();
let null_count = self.null_count();

let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
let values = self.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
}

/// Applies a unary and fallible function to all valid values in a primitive array
///
/// This is unlike [`Self::unary`] which will apply an infallible function to all rows
/// regardless of validity, in many cases this will be significantly faster and should
/// be preferred if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn try_unary<F, O, E>(&self, op: F) -> Result<PrimitiveArray<O>, E>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<O::Native, E>,
{
let data = self.data();
let len = self.len();
let null_count = self.null_count();

if null_count == 0 {
let values = self.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
}

let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
let mut buffer = BufferBuilder::<O::Native>::new(len);
buffer.append_n_zeroed(len);
let slice = buffer.as_slice_mut();

try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
unsafe { *slice.get_unchecked_mut(idx) = op(self.value_unchecked(idx))? };
Ok::<_, E>(())
})?;

Ok(unsafe {
build_primitive_array(len, buffer.finish(), null_count, null_buffer)
})
}
}

#[inline]
unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
len: usize,
buffer: Buffer,
null_count: usize,
null_buffer: Option<Buffer>,
) -> PrimitiveArray<O> {
PrimitiveArray::from(ArrayData::new_unchecked(
O::DATA_TYPE,
len,
Some(null_count),
null_buffer,
0,
vec![buffer],
vec![],
))
}

impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
Expand Down
76 changes: 4 additions & 72 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,92 +48,24 @@ unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
))
}

/// Applies an unary and infallible function to a primitive array.
/// This is the fastest way to perform an operation on a primitive array when
/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
///
/// # Implementation
///
/// This will apply the function for all values, including those on null slots.
/// This implies that the operation must be infallible for any value of the corresponding type
/// or this function may panic.
/// # Example
/// ```rust
/// # use arrow::array::Int32Array;
/// # use arrow::datatypes::Int32Type;
/// # use arrow::compute::kernels::arity::unary;
/// # fn main() {
/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
/// let c = unary::<_, _, Int32Type>(&array, |x| x * 2 + 1);
/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
/// # }
/// ```
/// See [`PrimitiveArray::unary`]
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> O::Native,
{
let data = array.data();
let len = data.len();
let null_count = data.null_count();

let null_buffer = data
.null_buffer()
.map(|b| b.bit_slice(data.offset(), data.len()));

let values = array.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
array.unary(op)
}

/// Applies a unary and fallible function to all valid values in a primitive array
///
/// This is unlike [`unary`] which will apply an infallible function to all rows regardless
/// of validity, in many cases this will be significantly faster and should be preferred
/// if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
/// See [`PrimitiveArray::try_unary`]
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native>,
{
let len = array.len();
let null_count = array.null_count();

if null_count == 0 {
let values = array.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
}

let null_buffer = array
.data_ref()
.null_buffer()
.map(|b| b.bit_slice(array.offset(), array.len()));

let mut buffer = BufferBuilder::<O::Native>::new(len);
buffer.append_n_zeroed(array.len());
let slice = buffer.as_slice_mut();

try_for_each_valid_idx(array.len(), 0, null_count, null_buffer.as_deref(), |idx| {
unsafe { *slice.get_unchecked_mut(idx) = op(array.value_unchecked(idx))? };
Ok::<_, ArrowError>(())
})?;

Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
array.try_unary(op)
}

/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
Expand Down