diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index f9e4e7675da2..4ac191ac977b 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::builder::{BooleanBufferBuilder, PrimitiveBuilder};
+use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder};
 use crate::iterator::PrimitiveIter;
 use crate::raw_pointer::RawPtrBox;
 use crate::temporal_conversions::{as_date, as_datetime, as_duration, as_time};
@@ -23,6 +23,7 @@ use crate::trusted_len::trusted_len_unzip;
 use crate::types::*;
 use crate::{print_long_array, Array, ArrayAccessor};
 use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
+use arrow_data::bit_iterator::try_for_each_valid_idx;
 use arrow_data::ArrayData;
 use arrow_schema::DataType;
 use chrono::{Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
@@ -298,20 +299,10 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
 
     /// Creates a PrimitiveArray based on a constant value with `count` elements
     pub fn from_value(value: T::Native, count: usize) -> Self {
-        // # Safety: iterator (0..count) correctly reports its length
-        let val_buf = unsafe { Buffer::from_trusted_len_iter((0..count).map(|_| value)) };
-        let data = unsafe {
-            ArrayData::new_unchecked(
-                T::DATA_TYPE,
-                val_buf.len() / std::mem::size_of::<<T as ArrowPrimitiveType>::Native>(),
-                None,
-                None,
-                0,
-                vec![val_buf],
-                vec![],
-            )
-        };
-        PrimitiveArray::from(data)
+        unsafe {
+            let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
+            build_primitive_array(count, val_buf, 0, None)
+        }
     }
 
     /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i`
@@ -332,6 +323,104 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
     ) -> impl Iterator<Item = Option<T::Native>> + 'a {
         indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
     }
+
+    /// Applies an unary and infallible function to a primitive array.
+    /// This is the fastest way to perform an operation on a primitive array when
+    /// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
+    ///
+    /// # Implementation
+    ///
+    /// This will apply the function for all values, including those on null slots.
+    /// This implies that the operation must be infallible for any value of the corresponding type
+    /// or this function may panic.
+    /// # Example
+    /// ```rust
+    /// # use arrow_array::{Int32Array, types::Int32Type};
+    /// # fn main() {
+    /// let array = Int32Array::from(vec![Some(5), Some(7), None]);
+    /// let c = array.unary(|x| x * 2 + 1);
+    /// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
+    /// # }
+    /// ```
+    pub fn unary<F, O>(&self, op: F) -> PrimitiveArray<O>
+    where
+        O: ArrowPrimitiveType,
+        F: Fn(T::Native) -> O::Native,
+    {
+        let data = self.data();
+        let len = self.len();
+        let null_count = self.null_count();
+
+        let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
+        let values = self.values().iter().map(|v| op(*v));
+        // JUSTIFICATION
+        //  Benefit
+        //      ~60% speedup
+        //  Soundness
+        //      `values` is an iterator with a known size because arrays are sized.
+        let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
+        unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
+    }
+
+    /// Applies a unary and fallible function to all valid values in a primitive array
+    ///
+    /// This is unlike [`Self::unary`] which will apply an infallible function to all rows
+    /// regardless of validity, in many cases this will be significantly faster and should
+    /// be preferred if `op` is infallible.
+    ///
+    /// Note: LLVM is currently unable to effectively vectorize fallible operations
+    pub fn try_unary<F, O, E>(&self, op: F) -> Result<PrimitiveArray<O>, E>
+    where
+        O: ArrowPrimitiveType,
+        F: Fn(T::Native) -> Result<O::Native, E>,
+    {
+        let data = self.data();
+        let len = self.len();
+        let null_count = self.null_count();
+
+        if null_count == 0 {
+            let values = self.values().iter().map(|v| op(*v));
+            // JUSTIFICATION
+            //  Benefit
+            //      ~60% speedup
+            //  Soundness
+            //      `values` is an iterator with a known size because arrays are sized.
+            let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
+            return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
+        }
+
+        let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
+        let mut buffer = BufferBuilder::<O::Native>::new(len);
+        buffer.append_n_zeroed(len);
+        let slice = buffer.as_slice_mut();
+
+        try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
+            unsafe { *slice.get_unchecked_mut(idx) = op(self.value_unchecked(idx))? };
+            Ok::<_, E>(())
+        })?;
+
+        Ok(unsafe {
+            build_primitive_array(len, buffer.finish(), null_count, null_buffer)
+        })
+    }
+}
+
+#[inline]
+unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
+    len: usize,
+    buffer: Buffer,
+    null_count: usize,
+    null_buffer: Option<Buffer>,
+) -> PrimitiveArray<O> {
+    PrimitiveArray::from(ArrayData::new_unchecked(
+        O::DATA_TYPE,
+        len,
+        Some(null_count),
+        null_buffer,
+        0,
+        vec![buffer],
+        vec![],
+    ))
 }
 
 impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 5f875e6ddf29..cb5184c0e9d4 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -48,92 +48,24 @@ unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
     ))
 }
 
-/// Applies an unary and infallible function to a primitive array.
-/// This is the fastest way to perform an operation on a primitive array when
-/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
-///
-/// # Implementation
-///
-/// This will apply the function for all values, including those on null slots.
-/// This implies that the operation must be infallible for any value of the corresponding type
-/// or this function may panic.
-/// # Example
-/// ```rust
-/// # use arrow::array::Int32Array;
-/// # use arrow::datatypes::Int32Type;
-/// # use arrow::compute::kernels::arity::unary;
-/// # fn main() {
-/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
-/// let c = unary::<_, _, Int32Type>(&array, |x| x * 2 + 1);
-/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
-/// # }
-/// ```
+/// See [`PrimitiveArray::unary`]
 pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
 where
     I: ArrowPrimitiveType,
     O: ArrowPrimitiveType,
     F: Fn(I::Native) -> O::Native,
 {
-    let data = array.data();
-    let len = data.len();
-    let null_count = data.null_count();
-
-    let null_buffer = data
-        .null_buffer()
-        .map(|b| b.bit_slice(data.offset(), data.len()));
-
-    let values = array.values().iter().map(|v| op(*v));
-    // JUSTIFICATION
-    //  Benefit
-    //      ~60% speedup
-    //  Soundness
-    //      `values` is an iterator with a known size because arrays are sized.
-    let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
-    unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
+    array.unary(op)
 }
 
-/// Applies a unary and fallible function to all valid values in a primitive array
-///
-/// This is unlike [`unary`] which will apply an infallible function to all rows regardless
-/// of validity, in many cases this will be significantly faster and should be preferred
-/// if `op` is infallible.
-///
-/// Note: LLVM is currently unable to effectively vectorize fallible operations
+/// See [`PrimitiveArray::try_unary`]
 pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
 where
     I: ArrowPrimitiveType,
     O: ArrowPrimitiveType,
     F: Fn(I::Native) -> Result<O::Native>,
 {
-    let len = array.len();
-    let null_count = array.null_count();
-
-    if null_count == 0 {
-        let values = array.values().iter().map(|v| op(*v));
-        // JUSTIFICATION
-        //  Benefit
-        //      ~60% speedup
-        //  Soundness
-        //      `values` is an iterator with a known size because arrays are sized.
-        let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
-        return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
-    }
-
-    let null_buffer = array
-        .data_ref()
-        .null_buffer()
-        .map(|b| b.bit_slice(array.offset(), array.len()));
-
-    let mut buffer = BufferBuilder::<O::Native>::new(len);
-    buffer.append_n_zeroed(array.len());
-    let slice = buffer.as_slice_mut();
-
-    try_for_each_valid_idx(array.len(), 0, null_count, null_buffer.as_deref(), |idx| {
-        unsafe { *slice.get_unchecked_mut(idx) = op(array.value_unchecked(idx))? };
-        Ok::<_, ArrowError>(())
-    })?;
-
-    Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
+    array.try_unary(op)
 }
 
 /// A helper function that applies an infallible unary function to a dictionary array with primitive value type.