-
Notifications
You must be signed in to change notification settings - Fork 810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add scalar comparison kernels for DictionaryArray #984
Changes from 6 commits
c5c7f35
155d60f
4896762
0d5d1b1
c147869
b5f04c5
ee7997c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,9 @@ use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuff | |
use crate::compute::binary_boolean_kernel; | ||
use crate::compute::util::combine_option_bitmap; | ||
use crate::datatypes::{ | ||
ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, | ||
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
ArrowNativeType, ArrowNumericType, ArrowPrimitiveType, DataType, Float32Type, | ||
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, | ||
UInt64Type, UInt8Type, | ||
}; | ||
use crate::error::{ArrowError, Result}; | ||
use crate::util::bit_util; | ||
|
@@ -200,6 +201,54 @@ macro_rules! compare_op_scalar_primitive { | |
}}; | ||
} | ||
|
||
/// Compares every element of a dictionary array with a scalar and evaluates
/// to a `Result<BooleanArray>`.
///
/// Arguments:
/// * `$left`  - the dictionary array to compare
/// * `$T`     - the Arrow primitive type the dictionary *values* are
///              downcast to (`PrimitiveArray<$T>`)
/// * `$right` - the scalar every value is compared against
/// * `$op`    - a binary predicate `(value, scalar) -> bool`
///
/// The predicate is evaluated once per dictionary value; the per-row result
/// is then looked up through the keys. Rows whose key is null get `false`
/// in the value buffer and stay masked by the validity bitmap copied from
/// `$left`.
///
/// Evaluates to `Err(ArrowError::ComputeError)` when the dictionary values
/// are not a `PrimitiveArray<$T>`.
///
/// Panics if a non-null key cannot be converted to `usize` or points past
/// the end of the dictionary values.
macro_rules! compare_dict_op_scalar {
    ($left:expr, $T:ident, $right:expr, $op:expr) => {{
        let null_bit_buffer = $left
            .data()
            .null_buffer()
            .map(|b| b.bit_slice($left.offset(), $left.len()));

        let values = $left.values();
        let dict_values = values.as_any().downcast_ref::<PrimitiveArray<$T>>();

        match dict_values {
            None => Err(ArrowError::ComputeError(format!(
                "Dictionary values are not stored as a primitive array of {}",
                stringify!($T)
            ))),
            Some(array) => {
                // One comparison per dictionary value; rows reuse these.
                let comparison: Vec<bool> = (0..array.len())
                    .map(|i| {
                        // SAFETY: `i` is drawn from `0..array.len()`.
                        let value = unsafe { array.value_unchecked(i) };
                        $op(value, $right)
                    })
                    .collect();

                // Null keys carry no meaningful index, so they must not be
                // used to index `comparison`.
                let result: Vec<bool> = $left
                    .keys()
                    .iter()
                    .map(|key| match key {
                        None => false,
                        Some(index) => {
                            // The message is only built on the failure path.
                            let position = index
                                .to_usize()
                                .unwrap_or_else(|| panic!("Failed at idx {:?}", index));
                            comparison[position]
                        }
                    })
                    .collect();

                // SAFETY: the iterator of a `Vec` reports its exact length.
                let buffer = unsafe {
                    MutableBuffer::from_trusted_len_iter_bool(result.into_iter())
                };

                // SAFETY: `buffer` holds one bit per key.
                // NOTE(review): this relies on `$left.keys().len()` being
                // equal to `$left.len()` - confirm against `DictionaryArray`.
                let data = unsafe {
                    ArrayData::new_unchecked(
                        DataType::Boolean,
                        $left.len(),
                        None,
                        null_bit_buffer,
                        0,
                        vec![Buffer::from(buffer)],
                        vec![],
                    )
                };
                Ok(BooleanArray::from(data))
            }
        }
    }};
}
|
||
/// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified | ||
/// comparison function. | ||
pub fn no_simd_compare_op<T, F>( | ||
|
@@ -693,6 +742,14 @@ pub fn eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a == b) | ||
} | ||
|
||
// Perform `left == right` operation on [`DictionaryArray`] and a scalar. | ||
// pub fn eq_dict<OffsetSize: ArrowPrimitiveType>( | ||
// left: &DictionaryArray<OffsetSize>, | ||
// right: &str, | ||
// ) -> Result<BooleanArray> { | ||
// compare_dict_op!(left, right, |a, b| a == b) | ||
// } | ||
|
||
#[inline] | ||
fn binary_boolean_op<F>( | ||
left: &BooleanArray, | ||
|
@@ -1200,6 +1257,29 @@ where | |
return compare_op_scalar!(left, right, |a, b| a == b); | ||
} | ||
|
||
/// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. | ||
pub fn eq_dict_scalar<T>( | ||
left: &DictionaryArray<T>, | ||
right: T::Native, | ||
) -> Result<BooleanArray> | ||
where | ||
T: ArrowNumericType, | ||
{ | ||
#[cfg(not(feature = "simd"))] | ||
println!("{}", std::any::type_name::<T>()); | ||
return compare_dict_op_scalar!(left, T, right, |a, b| a == b); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @matthewmturner this is what I was trying to say. I think the way you have this function, with a single type parameter `T`, forces the dictionary keys and the dictionary values to have the same type. Here is a sketch of how this might work:
/// Perform `left == right` operation on a [`DictionaryArray`] and a numeric scalar value.
pub fn eq_dict_scalar<T, K>(
left: &DictionaryArray<K>,
right: T::Native,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
K: ArrowNumericType,
{
// Compare against the dictionary values only (e.g. if the dictionary is
// {A, B} and the keys are {1, 0, 1, 1}, that represents the values B, A,
// B, B).
//
// So we compare just the dictionary values {A, B} to `right` and then
// use each key to look up the result for the corresponding row.
//
// TODO macro-ize this
let dictionary_comparison = match left.values().data_type() {
DataType::Int8 => {
eq_scalar(as_primitive_array::<T>(left.values()), right)
}
// TODO fill in Int16, Int32, etc
_ => unimplemented!("Should error: dictionary did not store values of type T")
}?;
// Required for safety below
assert_eq!(dictionary_comparison.len(), left.values().len());
// Now, look up the dictionary for each output
let result: BooleanArray = left.keys()
.iter()
.map(|key| {
// figure out how the dictionary element at this index
// compared to the scalar
key.map(|key| {
// safety: the original array's indices were valid, i.e. in
// `0..left.values().len()`, and `dictionary_comparison`
// is the same size, checked above
unsafe {
// it would be nice to avoid checking the conversion each time
let key = key.to_usize().expect("Dictionary index not usize");
dictionary_comparison.value_unchecked(key)
}
})
})
.collect();
Ok(result)
}
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks very much for putting this together and for the explanation. I'll work on implementing it! |
||
|
||
// pub fn eq_dict_utf8_scalar<OffsetSize>( | ||
// left: &DictionaryArray<OffsetSize>, | ||
// right: &str, | ||
// ) -> Result<BooleanArray> | ||
// where | ||
// OffsetSize: StringOffsetSizeTrait + ArrowPrimitiveType, | ||
// { | ||
// #[cfg(not(feature = "simd"))] | ||
// return compare_dict_op_scalar!(left, OffsetSize, right, |a, b| a == b); | ||
// } | ||
/// Perform `left != right` operation on two [`PrimitiveArray`]s. | ||
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray> | ||
where | ||
|
@@ -2032,6 +2112,44 @@ mod tests { | |
); | ||
} | ||
|
||
// Builds a UInt8-keyed dictionary of UInt8 values with a null in the middle
// and checks that comparing it with a scalar keeps the null and flags only
// the matching element.
#[test]
fn test_dict_eq_scalar() {
    let mut dict_builder = PrimitiveDictionaryBuilder::new(
        PrimitiveBuilder::<UInt8Type>::new(3),
        PrimitiveBuilder::<UInt8Type>::new(2),
    );
    dict_builder.append(123).unwrap();
    dict_builder.append_null().unwrap();
    dict_builder.append(223).unwrap();
    let dictionary = dict_builder.finish();

    let actual = eq_dict_scalar(&dictionary, 123).unwrap();
    let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
    assert_eq!(actual, expected);
}
|
||
// #[test] | ||
// fn test_dict_eq_utf8_scalar() { | ||
// let a: DictionaryArray<Int8Type> = vec!["a", "b", "c"].into_iter().collect(); | ||
// let a_eq = eq_dict_utf8_scalar(&a, "b").unwrap(); | ||
// assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
// } | ||
// #[test] | ||
// fn test_dict_neq_scalar() { | ||
// let a: DictionaryArray<Int8Type> = | ||
// vec!["hi","hello", "world"].into_iter().collect(); | ||
// let a_eq = neq_dict_scalar(&a, "hello").unwrap(); | ||
// assert_eq!(a_eq, BooleanArray::from(vec![true, false, true])); | ||
// } | ||
|
||
// #[test] | ||
// fn test_dict_lt_scalar() { | ||
// let a: DictionaryArray<Int8Type> = | ||
// vec!["hi","hello", "world"].into_iter().collect(); | ||
// let a_eq = lt_dict_scalar(&a, "hi").unwrap(); | ||
// assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
// } | ||
|
||
macro_rules! test_utf8_scalar { | ||
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { | ||
#[test] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think the `values()` (the dictionary size) has to be the same as the size of the overall array 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're referring to the safety comment, I just hadn't removed that yet.