Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace macro with TypedDictionaryArray in comparison kernels #2514

Merged
merged 1 commit into from
Aug 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2154,49 +2154,39 @@ macro_rules! typed_dict_compares {

/// Helper function to perform boolean lambda function on values from two dictionary arrays, this
/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize)
macro_rules! compare_dict_op {
($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{
if $left.len() != $right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

// Safety justification: Since the inputs are valid Arrow arrays, all values are
// valid indexes into the dictionary (which is verified during construction)

let left_iter = unsafe {
$left
.values()
.as_any()
.downcast_ref::<$value_ty>()
.unwrap()
.take_iter_unchecked($left.keys_iter())
};
fn compare_dict_op<'a, K, V, F>(
left: TypedDictionaryArray<'a, K, V>,
right: TypedDictionaryArray<'a, K, V>,
op: F,
) -> Result<BooleanArray>
where
K: ArrowNumericType,
V: Sync + Send,
&'a V: ArrayAccessor,
F: Fn(<&V as ArrayAccessor>::Item, <&V as ArrayAccessor>::Item) -> bool,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform comparison operation on arrays of different length"
.to_string(),
));
}

let right_iter = unsafe {
$right
.values()
.as_any()
.downcast_ref::<$value_ty>()
.unwrap()
.take_iter_unchecked($right.keys_iter())
};
let left_iter = left.into_iter();
let right_iter = right.into_iter();

let result = left_iter
.zip(right_iter)
.map(|(left_value, right_value)| {
if let (Some(left), Some(right)) = (left_value, right_value) {
Some($op(left, right))
} else {
None
}
})
.collect();
let result = left_iter
.zip(right_iter)
.map(|(left_value, right_value)| {
if let (Some(left), Some(right)) = (left_value, right_value) {
Some(op(left, right))
} else {
None
}
})
.collect();

Ok(result)
}};
Ok(result)
}

/// Perform given operation on two `DictionaryArray`s.
Expand All @@ -2208,10 +2198,14 @@ pub fn cmp_dict<K, T, F>(
) -> Result<BooleanArray>
where
K: ArrowNumericType,
T: ArrowNumericType,
T: ArrowNumericType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
compare_dict_op!(left, right, op, PrimitiveArray<T>)
compare_dict_op(
left.downcast_dict::<PrimitiveArray<T>>().unwrap(),
right.downcast_dict::<PrimitiveArray<T>>().unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s which value type is
Expand All @@ -2225,7 +2219,11 @@ where
K: ArrowNumericType,
F: Fn(bool, bool) -> bool,
{
compare_dict_op!(left, right, op, BooleanArray)
compare_dict_op(
left.downcast_dict::<BooleanArray>().unwrap(),
right.downcast_dict::<BooleanArray>().unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s which value type is
Expand All @@ -2239,7 +2237,14 @@ where
K: ArrowNumericType,
F: Fn(&str, &str) -> bool,
{
compare_dict_op!(left, right, op, GenericStringArray<OffsetSize>)
compare_dict_op(
left.downcast_dict::<GenericStringArray<OffsetSize>>()
.unwrap(),
right
.downcast_dict::<GenericStringArray<OffsetSize>>()
.unwrap(),
op,
)
}

/// Perform the given operation on two `DictionaryArray`s which value type is
Expand All @@ -2253,7 +2258,14 @@ where
K: ArrowNumericType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_dict_op!(left, right, op, GenericBinaryArray<OffsetSize>)
compare_dict_op(
left.downcast_dict::<GenericBinaryArray<OffsetSize>>()
.unwrap(),
right
.downcast_dict::<GenericBinaryArray<OffsetSize>>()
.unwrap(),
op,
)
}

/// Perform `left == right` operation on two (dynamic) [`Array`]s.
Expand Down