Skip to content

Commit

Permalink
Use downcast_dictionary_array in unary_dyn (#2663)
Browse files: browse the repository at this point in the history
* Use downcast_dictionary_array in unary_dyn

* Further cleanups

* Clippy
  • Loading branch information
tustvold authored Sep 6, 2022
1 parent 463240a commit 0c85233
Showing 1 changed file with 23 additions and 102 deletions.
125 changes: 23 additions & 102 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
use crate::buffer::Buffer;
use crate::datatypes::{
ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::datatypes::{ArrowNumericType, ArrowPrimitiveType};
use crate::downcast_dictionary_array;
use crate::error::{ArrowError, Result};
use std::sync::Arc;

Expand All @@ -31,14 +29,13 @@ fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
array: &PrimitiveArray<I>,
buffer: Buffer,
) -> ArrayData {
let data = array.data();
unsafe {
ArrayData::new_unchecked(
O::DATA_TYPE,
array.len(),
None,
array
.data_ref()
.null_buffer()
Some(data.null_count()),
data.null_buffer()
.map(|b| b.bit_slice(array.offset(), array.len())),
0,
vec![buffer],
Expand Down Expand Up @@ -84,39 +81,15 @@ where
}

/// A helper function that applies a unary function to a dictionary array with primitive value type.
#[allow(clippy::redundant_closure)]
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let dict_values = array
.values()
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap();

let values = dict_values
.iter()
.map(|v| v.map(|value| op(value)))
.collect::<PrimitiveArray<T>>();

let keys = array.keys();

let mut data = ArrayData::builder(array.data_type().clone())
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
.add_child_data(values.data().clone());

match keys.data().null_buffer() {
Some(buffer) if keys.data().null_count() > 0 => {
data = data
.null_bit_buffer(Some(buffer.clone()))
.null_count(keys.data().null_count());
}
_ => data = data.null_count(0),
}
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op).into_data();
let data = array.data().clone().into_builder().child_data(vec![values]);

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
Expand All @@ -128,73 +101,21 @@ where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
match array.data_type() {
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap(),
op,
),
DataType::Int16 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int16Type>>()
.unwrap(),
op,
),
DataType::Int32 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap(),
op,
),
DataType::Int64 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int64Type>>()
.unwrap(),
op,
),
DataType::UInt8 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt8Type>>()
.unwrap(),
op,
),
DataType::UInt16 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt16Type>>()
.unwrap(),
op,
),
DataType::UInt32 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt32Type>>()
.unwrap(),
op,
),
DataType::UInt64 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt64Type>>()
.unwrap(),
op,
),
t => Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on dictionary array of key type {}.",
t
))),
},
_ => Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
))),
downcast_dictionary_array! {
array => unary_dict::<_, F, T>(array, op),
t => {
if t == &T::DATA_TYPE {
Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
)))
} else {
Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on array of type {}",
t
)))
}
}
}
}

Expand Down

0 comments on commit 0c85233

Please sign in to comment.