-
Notifications
You must be signed in to change notification settings - Fork 817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add scalar comparison kernels for DictionaryArray #984
Changes from 2 commits
c5c7f35
155d60f
4896762
0d5d1b1
c147869
b5f04c5
ee7997c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,8 +27,9 @@ use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuff | |
use crate::compute::binary_boolean_kernel; | ||
use crate::compute::util::combine_option_bitmap; | ||
use crate::datatypes::{ | ||
ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, | ||
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
ArrowNativeType, ArrowNumericType, ArrowPrimitiveType, DataType, Dictionary, | ||
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, | ||
UInt32Type, UInt64Type, UInt8Type, | ||
}; | ||
use crate::error::{ArrowError, Result}; | ||
use crate::util::bit_util; | ||
|
@@ -200,6 +201,42 @@ macro_rules! compare_op_scalar_primitive { | |
}}; | ||
} | ||
|
||
macro_rules! compare_dict_op_scalar { | ||
($left:expr, $right:expr, $op:expr) => {{ | ||
let null_bit_buffer = $left | ||
.data() | ||
.null_buffer() | ||
.map(|b| b.bit_slice($left.offset(), $left.len())); | ||
|
||
let values = $left | ||
.values() | ||
.as_any() | ||
.downcast_ref::<StringArray>() | ||
.unwrap(); | ||
|
||
// Safety: | ||
// `i < $left.len()` | ||
let comparison = (0..$left.len()).map(|i| unsafe { | ||
let key = $left.keys().value_unchecked(i).to_usize().unwrap(); | ||
$op(values.value_unchecked(key), $right) | ||
}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think one of the main points of this ticket is to avoid the call here to I like to think about the goal in by thinking "what would happen with DictionaryArray with 1000000 entries but a dictionary of size 1?" -- the way you have this PR, I think we would call So the pattern I think we are looking, at least for the constant kernels is: In pseudo code:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, thanks for explanation. I am looking into this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alamb im struggling with the second step in your pseudocode given that my understanding is that the values could be of any There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤔 Yes this is definitely tricky. Maybe taking a step back, and think about the usecase: comparing DictionaryArrays to literals. For example, if you look at the comparison kernels (for
With each being typed based on the type of scalar (because the arrays are typed) The issue with a So i am thinking we would need something like
where each of those kernels would be able to downcast the array appropriately. However, having three functions for each dict kernel seems somewhat crazy. That is where my dyn idea was coming from. If we are going to add three new kernels for each operator (
etc Which handle DictionaryArray as well as dispatching to the other Does that make sense? I can try and sketch out the interface this weekend sometime There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for explanation! yes, it does make sense. i think i was trying to do too much in my macros / functions which was causing my confusion. i think if i can get one of the below to work that should give me my baseline to do the rest.
|
||
// same as $left.len() | ||
let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; | ||
|
||
let data = unsafe { | ||
ArrayData::new_unchecked( | ||
DataType::Boolean, | ||
$left.len(), | ||
None, | ||
null_bit_buffer, | ||
0, | ||
vec![Buffer::from(buffer)], | ||
vec![], | ||
) | ||
}; | ||
Ok(BooleanArray::from(data)) | ||
}}; | ||
} | ||
/// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified | ||
/// comparison function. | ||
pub fn no_simd_compare_op<T, F>( | ||
|
@@ -693,6 +730,14 @@ pub fn eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a == b) | ||
} | ||
|
||
/// Perform `left == right` operation on [`DictionaryArray`] and a scalar. | ||
pub fn eq_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a == b) | ||
} | ||
|
||
#[inline] | ||
fn binary_boolean_op<F>( | ||
left: &BooleanArray, | ||
|
@@ -802,6 +847,14 @@ pub fn neq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a != b) | ||
} | ||
|
||
/// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. | ||
pub fn neq_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a != b) | ||
} | ||
|
||
/// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. | ||
pub fn lt_utf8<OffsetSize: StringOffsetSizeTrait>( | ||
left: &GenericStringArray<OffsetSize>, | ||
|
@@ -818,6 +871,14 @@ pub fn lt_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a < b) | ||
} | ||
|
||
/// Perform `left < right` operation on [`DictionaryArray`] and a scalar. | ||
pub fn lt_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a < b) | ||
} | ||
|
||
/// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. | ||
pub fn lt_eq_utf8<OffsetSize: StringOffsetSizeTrait>( | ||
left: &GenericStringArray<OffsetSize>, | ||
|
@@ -834,6 +895,14 @@ pub fn lt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a <= b) | ||
} | ||
|
||
/// Perform `left <= right` operation on [`DictionaryArray`] and a scalar. | ||
pub fn lt_eq_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a <= b) | ||
} | ||
|
||
/// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. | ||
pub fn gt_utf8<OffsetSize: StringOffsetSizeTrait>( | ||
left: &GenericStringArray<OffsetSize>, | ||
|
@@ -850,6 +919,14 @@ pub fn gt_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a > b) | ||
} | ||
|
||
/// Perform `left > right` operation on [`DictionaryArray`] and a scalar. | ||
pub fn gt_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a > b) | ||
} | ||
|
||
/// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. | ||
pub fn gt_eq_utf8<OffsetSize: StringOffsetSizeTrait>( | ||
left: &GenericStringArray<OffsetSize>, | ||
|
@@ -866,6 +943,14 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>( | |
compare_op_scalar!(left, right, |a, b| a >= b) | ||
} | ||
|
||
/// Perform `left >= right` operation on [`DictionaryArray`] and a scalar. | ||
pub fn gt_eq_dict_scalar<OffsetSize: ArrowPrimitiveType>( | ||
left: &DictionaryArray<OffsetSize>, | ||
right: &str, | ||
) -> Result<BooleanArray> { | ||
compare_dict_op_scalar!(left, right, |a, b| a >= b) | ||
} | ||
|
||
/// Helper function to perform boolean lambda function on values from two arrays using | ||
/// SIMD. | ||
#[cfg(feature = "simd")] | ||
|
@@ -2032,6 +2117,30 @@ mod tests { | |
); | ||
} | ||
|
||
#[test] | ||
fn test_dict_eq_scalar() { | ||
let a: DictionaryArray<Int8Type> = | ||
vec!["hi","hello", "world"].into_iter().collect(); | ||
let a_eq = eq_dict_scalar(&a, "hello").unwrap(); | ||
assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
} | ||
|
||
#[test] | ||
fn test_dict_neq_scalar() { | ||
let a: DictionaryArray<Int8Type> = | ||
vec!["hi","hello", "world"].into_iter().collect(); | ||
let a_eq = neq_dict_scalar(&a, "hello").unwrap(); | ||
assert_eq!(a_eq, BooleanArray::from(vec![true, false, true])); | ||
} | ||
|
||
#[test] | ||
fn test_dict_lt_scalar() { | ||
let a: DictionaryArray<Int8Type> = | ||
vec!["hi","hello", "world"].into_iter().collect(); | ||
let a_eq = lt_dict_scalar(&a, "hi").unwrap(); | ||
assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); | ||
} | ||
|
||
macro_rules! test_utf8_scalar { | ||
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { | ||
#[test] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The values array can be anything (not always a
StringArray
) -- perhaps this would be a good place to use thedyn_XX
kernels -- to compare the values array with$right
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from my understanding of the dyn kernels those cant be used when comparing to constant right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🤔 yes you are correct -- we would need to add
dyn_XX_lit
type kernels, but that seems a bit overkill for this PRThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if the primary use case for this PR was comparing dict array to constant then maybe it makes sense for me to do a separate PR for that first and then come back to this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think focusing on the usecase of comparing dict array to constant is the best choice for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok! Will start with that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@alamb ive been reviewing this but i think i might be missing something. my understanding is that my code above is for getting the dictionary values, which can be of any type (of course above im only handling
StringArray
).But then you mention using the new
dyn_xx
kernels / creatingdyn_xx_lit
kernels. Since theres no actual compute being done here, what would thedyn
kernels be used for? Or were you referring to using the kernels to replace more than just that section of code?to me it looks like i need a macro to downcast
DictionaryArray.values()
into whatever type the values are, and then i could use something likedyn_xx_lit
on that in order to get the comparison results. Is this roughly what you had in mind?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am very sorry for confusing this conversation with mentioning
dyn_xx_lit
.What I was (inarticulately) trying to say was that once you have
eq_dict_scalar
(and you will likely also needeq_dict_scalar_utf8
) we will end up with several different ways to compare an array to a scalar, depending on the array typeSo I was thinking ahead to adding functions like
But definitely not for this PR