Skip to content

Commit

Permalink
Add PrimitiveArray::new (apache#3879)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Mar 22, 2023
1 parent dc23fa3 commit 82c99bd
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 59 deletions.
25 changes: 5 additions & 20 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,11 @@ use arrow_array::iterator::ArrayIter;
use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::{Buffer, MutableBuffer};
use arrow_buffer::{Buffer, MutableBuffer, ScalarBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::ArrowError;
use std::sync::Arc;

/// Assembles a [`PrimitiveArray`] of type `O` from a length, a values buffer and an
/// optional null buffer.
///
/// The array is obtained by converting the `ArrayData` produced by an `ArrayDataBuilder`
/// configured with `O::DATA_TYPE`, `len`, `nulls` and `buffer` as its only buffer.
///
/// # Safety
///
/// The `ArrayData` is created through `build_unchecked`, so presumably none of the inputs
/// are validated here: the caller is expected to guarantee that `buffer` and `nulls` are
/// consistent with `len` and `O::DATA_TYPE`.
/// NOTE(review): confirm the exact requirements against `ArrayDataBuilder::build_unchecked`.
#[inline]
unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
len: usize,
buffer: Buffer,
nulls: Option<NullBuffer>,
) -> PrimitiveArray<O> {
PrimitiveArray::from(
ArrayDataBuilder::new(O::DATA_TYPE)
.len(len)
.nulls(nulls)
.buffers(vec![buffer])
.build_unchecked(),
)
}

/// See [`PrimitiveArray::unary`]
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
Expand Down Expand Up @@ -224,8 +209,7 @@ where
// Soundness
// `values` is an iterator with a known size from a PrimitiveArray
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };

Ok(unsafe { build_primitive_array(len, buffer, nulls) })
Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls))
}

/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating
Expand Down Expand Up @@ -328,7 +312,8 @@ where
Ok::<_, ArrowError>(())
})?;

Ok(unsafe { build_primitive_array(len, buffer.finish(), Some(nulls)) })
let values = buffer.finish().into();
Ok(PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls)))
}
}

Expand Down Expand Up @@ -412,7 +397,7 @@ where
buffer.push_unchecked(op(a.value_unchecked(idx), b.value_unchecked(idx))?);
};
}
Ok(unsafe { build_primitive_array(len, buffer.into(), None) })
Ok(PrimitiveArray::new(O::DATA_TYPE, buffer.into(), None))
}

/// This intentional inline(never) attribute helps LLVM optimize the loop.
Expand Down
96 changes: 57 additions & 39 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow_buffer::{
i256, ArrowNativeType, BooleanBuffer, Buffer, NullBuffer, ScalarBuffer,
};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType};
use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime};
use half::f16;
Expand Down Expand Up @@ -251,19 +251,58 @@ pub struct PrimitiveArray<T: ArrowPrimitiveType> {
/// Underlying ArrayData
data: ArrayData,
/// Values data
raw_values: ScalarBuffer<T::Native>,
values: ScalarBuffer<T::Native>,
}

impl<T: ArrowPrimitiveType> Clone for PrimitiveArray<T> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
raw_values: self.raw_values.clone(),
values: self.values.clone(),
}
}
}

impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// Create a new [`PrimitiveArray`] from the provided data_type, values, nulls
///
/// The length of the resulting array is `values.len()`.
///
/// # Panics
///
/// Panics if:
/// - `nulls` is `Some` and `values.len() != nulls.len()`
/// - `!Self::is_compatible(data_type)`
pub fn new(
data_type: DataType,
values: ScalarBuffer<T::Native>,
nulls: Option<NullBuffer>,
) -> Self {
Self::assert_compatible(&data_type);
// The length check is only performed when a null buffer is supplied
if let Some(n) = nulls.as_ref() {
assert_eq!(values.len(), n.len());
}

// TODO: Don't store ArrayData inside arrays (#3880)
// SAFETY (as far as visible here): `data_type` passed `assert_compatible` and the
// null buffer length, if any, was checked against `values.len()` above.
// NOTE(review): confirm these checks cover everything `build_unchecked` skips.
let data = unsafe {
ArrayData::builder(data_type)
.len(values.len())
.nulls(nulls)
// The ArrayData receives a clone of the buffer underlying `values`
.buffers(vec![values.inner().clone()])
.build_unchecked()
};

Self { data, values }
}

/// Asserts that `data_type` is compatible with `Self`
///
/// # Panics
///
/// Panics if `Self::is_compatible(data_type)` returns `false`; the panic message
/// reports both `T::DATA_TYPE` and the offending `data_type`.
fn assert_compatible(data_type: &DataType) {
assert!(
Self::is_compatible(data_type),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data_type
);
}

/// Returns the length of this array.
#[inline]
pub fn len(&self) -> usize {
Expand All @@ -278,7 +317,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// Returns the values of this array
#[inline]
pub fn values(&self) -> &ScalarBuffer<T::Native> {
&self.raw_values
&self.values
}

/// Returns a new primitive array builder
Expand Down Expand Up @@ -308,7 +347,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
/// caller must ensure that the passed in offset is less than the array len()
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> T::Native {
*self.raw_values.get_unchecked(i)
*self.values.get_unchecked(i)
}

/// Returns the primitive value at index `i`.
Expand Down Expand Up @@ -346,7 +385,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
pub fn from_value(value: T::Native, count: usize) -> Self {
unsafe {
let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
build_primitive_array(count, val_buf, None)
Self::new(T::DATA_TYPE, val_buf.into(), None)
}
}

Expand Down Expand Up @@ -422,7 +461,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
F: Fn(T::Native) -> O::Native,
{
let data = self.data();
let len = self.len();

let nulls = data.nulls().cloned();
let values = self.values().iter().map(|v| op(*v));
Expand All @@ -432,7 +470,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
unsafe { build_primitive_array(len, buffer, nulls) }
PrimitiveArray::new(O::DATA_TYPE, buffer.into(), nulls)
}

/// Applies an unary and infallible function to a mutable primitive array.
Expand Down Expand Up @@ -495,7 +533,8 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
None => (0..len).try_for_each(f)?,
}

Ok(unsafe { build_primitive_array(len, buffer.finish(), nulls) })
let values = buffer.finish().into();
Ok(PrimitiveArray::new(O::DATA_TYPE, values, nulls))
}

/// Applies an unary and fallible function to all valid values in a mutable primitive array.
Expand Down Expand Up @@ -579,13 +618,9 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
});

let nulls = BooleanBuffer::new(null_builder.finish(), 0, len);
unsafe {
build_primitive_array(
len,
buffer.finish(),
Some(NullBuffer::new_unchecked(nulls, out_null_count)),
)
}
let values = buffer.finish().into();
let nulls = unsafe { NullBuffer::new_unchecked(nulls, out_null_count) };
PrimitiveArray::new(O::DATA_TYPE, values, Some(nulls))
}

/// Returns `PrimitiveBuilder` of this primitive array for mutating its values if the underlying
Expand All @@ -599,7 +634,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
.slice_with_length(self.data.offset() * element_len, len * element_len);

drop(self.data);
drop(self.raw_values);
drop(self.values);

let try_mutable_null_buffer = match null_bit_buffer {
None => Ok(None),
Expand Down Expand Up @@ -647,21 +682,6 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
}
}

/// Assembles a [`PrimitiveArray`] of type `O` from a length, a values buffer and an
/// optional null buffer.
///
/// The array is obtained by converting the `ArrayData` produced by an `ArrayDataBuilder`
/// configured with `O::DATA_TYPE`, `len`, `buffer` as its only buffer, and `nulls`.
///
/// # Safety
///
/// The `ArrayData` is created through `build_unchecked`, so presumably none of the inputs
/// are validated here: the caller is expected to guarantee that `buffer` and `nulls` are
/// consistent with `len` and `O::DATA_TYPE`.
/// NOTE(review): confirm the exact requirements against `ArrayDataBuilder::build_unchecked`.
#[inline]
unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
len: usize,
buffer: Buffer,
nulls: Option<NullBuffer>,
) -> PrimitiveArray<O> {
PrimitiveArray::from(
ArrayDataBuilder::new(O::DATA_TYPE)
.len(len)
.buffers(vec![buffer])
.nulls(nulls)
.build_unchecked(),
)
}

impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
fn from(array: PrimitiveArray<T>) -> Self {
array.data
Expand Down Expand Up @@ -1052,12 +1072,7 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
/// Constructs a `PrimitiveArray` from an owned `ArrayData`.
impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
fn from(data: ArrayData) -> Self {
assert!(
Self::is_compatible(data.data_type()),
"PrimitiveArray expected ArrayData with type {} got {}",
T::DATA_TYPE,
data.data_type()
);
Self::assert_compatible(data.data_type());
assert_eq!(
data.buffers().len(),
1,
Expand All @@ -1066,7 +1081,10 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {

let raw_values =
ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len());
Self { data, raw_values }
Self {
data,
values: raw_values,
}
}
}

Expand Down

0 comments on commit 82c99bd

Please sign in to comment.