Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scalar comparison kernels for DictionaryArray #984

Closed
Closed
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 120 additions & 2 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuff
use crate::compute::binary_boolean_kernel;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
ArrowNativeType, ArrowNumericType, ArrowPrimitiveType, DataType, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
Expand Down Expand Up @@ -200,6 +201,54 @@ macro_rules! compare_op_scalar_primitive {
}};
}

// Compares every slot of a dictionary array against a scalar and yields
// `Ok(BooleanArray)`.
//
// Arguments:
// * `$left`  - the dictionary array expression (its `keys()`, `values()`,
//   `data()`, `offset()` and `len()` are used below).
// * `$T`     - the primitive type the dictionary *values* are downcast to.
// * `$right` - the scalar passed as the second argument of `$op`.
// * `$op`    - the comparison, called as `$op(value, $right)`, returning `bool`.
//
// Strategy: evaluate `$op` once per dictionary value, then translate each key
// into the precomputed result for the value it points at.
macro_rules! compare_dict_op_scalar {
($left:expr, $T:ident, $right:expr, $op:expr) => {{
// Slice the validity bitmap to this array's window so that the result can
// be built with an offset of 0 further down.
let null_bit_buffer = $left
.data()
.null_buffer()
.map(|b| b.bit_slice($left.offset(), $left.len()));

let values = $left.values();

// NOTE(review): this `unwrap` panics when the dictionary values are not a
// `PrimitiveArray<$T>`; the macro itself does not check the value type.
let array = values
.as_any()
.downcast_ref::<PrimitiveArray<$T>>()
.unwrap();

// SAFETY: `i` ranges over `0..array.len()`, so it is always in bounds for
// `array` (the dictionary values, whose length is independent of `$left.len()`).
let comparison: Vec<bool> = (0..array.len())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't think the values() (the dictionary size) has to the same as the size of the overall array 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if youre referring to the safety comment i just hadnt removed that yet

.map(|i| unsafe { $op(array.value_unchecked(i), $right) })
.collect();

// One output per key: look up how the referenced dictionary value compared.
// NOTE(review): `value(key)` is read for every slot, including null ones;
// presumably a null slot holds an arbitrary key, so `comparison[..]` may be
// indexed with it (and panics if it is out of range) - TODO confirm.
let result: Vec<bool> = (0..$left.keys().len())
.map(|key| {
let index = $left.keys().value(key);
comparison[index
.to_usize()
// NOTE(review): the `format!` argument is evaluated for every key, even
// when the conversion succeeds (Clippy `expect_fun_call`).
.expect(format!("Failed at idx {:?}", index).as_str())]
})
.collect();

// `result` has one entry per key, i.e. the same length as `$left.len()`.
// SAFETY: the iterator comes from a `Vec`, so its reported length is exact.
let buffer =
unsafe { MutableBuffer::from_trusted_len_iter_bool(result.into_iter()) };

// Assemble the boolean array without validation: `$left.len()` slots, the
// sliced validity bitmap from above, offset 0, one values buffer, no children.
// NOTE(review): soundness relies on `buffer` and `null_bit_buffer` both
// covering `$left.len()` bits - nothing here re-checks that.
let data = unsafe {
ArrayData::new_unchecked(
DataType::Boolean,
$left.len(),
None,
null_bit_buffer,
0,
vec![Buffer::from(buffer)],
vec![],
)
};
Ok(BooleanArray::from(data))
}};
}

/// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified
/// comparison function.
pub fn no_simd_compare_op<T, F>(
Expand Down Expand Up @@ -693,6 +742,14 @@ pub fn eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
compare_op_scalar!(left, right, |a, b| a == b)
}

// Perform `left == right` operation on [`DictionaryArray`] and a scalar.
// pub fn eq_dict<OffsetSize: ArrowPrimitiveType>(
// left: &DictionaryArray<OffsetSize>,
// right: &str,
// ) -> Result<BooleanArray> {
// compare_dict_op!(left, right, |a, b| a == b)
// }

#[inline]
fn binary_boolean_op<F>(
left: &BooleanArray,
Expand Down Expand Up @@ -1200,6 +1257,29 @@ where
return compare_op_scalar!(left, right, |a, b| a == b);
}

/// Perform `left == right` operation on a [`DictionaryArray`] and a scalar value.
///
/// The comparison is evaluated once per dictionary value and the outcome is
/// then looked up for every key, so the cost of `==` scales with the
/// dictionary size rather than with the array length.
///
/// `T` is used both as the key type of `left` and as the type the dictionary
/// values are downcast to, so only dictionaries whose keys and values share
/// the same primitive type are supported by this signature.
///
/// # Panics
///
/// Panics if the dictionary values are not a `PrimitiveArray<T>`, because the
/// underlying macro unwraps the downcast.
pub fn eq_dict_scalar<T>(
    left: &DictionaryArray<T>,
    right: T::Native,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
{
    // No SIMD variant exists for dictionaries, so the same scalar path is used
    // regardless of the `simd` feature (the stray debug `println!` that
    // carried the `cfg` attribute has been removed).
    compare_dict_op_scalar!(left, T, right, |a, b| a == b)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthewmturner this is what I was trying to say.

I think the way you have this function with a single T generic parameter means one could not compare a DictionaryArray<Int8> (i.e. one that has keys / indexes of Int8) that has values of type DataType::UInt16

Here is a sketch of how this might work:

/// Perform `left == right` operation on a [`DictionaryArray`] and a numeric scalar value.
///
/// `K` is the key (index) type of the dictionary, while `T` is the type of the
/// dictionary values and therefore of the scalar `right`; the two are
/// independent of each other.
pub fn eq_dict_scalar<T, K>(
    left: &DictionaryArray<K>,
    right: T::Native,
) -> Result<BooleanArray>
where
    T: ArrowNumericType,
    K: ArrowNumericType,
{
    // Compare against the dictionary values rather than the decoded array
    // (e.g. if the dictionary is {A, B} and the keys are {1,0,1,1}, the array
    // represents the values B, A, B, B).
    //
    // So we compare just the dictionary values {A, B} to `right` once, and
    // then use each key to look up the precomputed result.
    //
    // TODO macro-ize this

    let dictionary_comparison = match left.values().data_type() {
        DataType::Int8 => {
            eq_scalar(as_primitive_array::<T>(left.values()), right)
        }
        // TODO fill in Int16, Int32, etc
        _ => unimplemented!("Should error: dictionary did not store values of type T")
    }?;

    // Required for safety below
    assert_eq!(dictionary_comparison.len(), left.values().len());

    // Now, look up the dictionary result for each output slot
    let result: BooleanArray = left.keys()
        .iter()
        .map(|key| {
            // Figure out how the dictionary element at this index
            // compared to the scalar; a null key stays null.
            key.map(|key| {
                // SAFETY: the original array's indices are assumed valid, i.e. in
                // `0..left.values().len()`, and `dictionary_comparison`
                // is the same size, checked above
                unsafe {
                    // it would be nice to avoid checking the conversion each time
                    let key = key.to_usize().expect("Dictionary index not usize");
                    dictionary_comparison.value_unchecked(key)
                }
            })
        })
        .collect();

    Ok(result)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx much for putting this together and the explanation. I'll work on implementing it!


// pub fn eq_dict_utf8_scalar<OffsetSize>(
// left: &DictionaryArray<OffsetSize>,
// right: &str,
// ) -> Result<BooleanArray>
// where
// OffsetSize: StringOffsetSizeTrait + ArrowPrimitiveType,
// {
// #[cfg(not(feature = "simd"))]
// return compare_dict_op_scalar!(left, OffsetSize, right, |a, b| a == b);
// }
/// Perform `left != right` operation on two [`PrimitiveArray`]s.
pub fn neq<T>(left: &PrimitiveArray<T>, right: &PrimitiveArray<T>) -> Result<BooleanArray>
where
Expand Down Expand Up @@ -2032,6 +2112,44 @@ mod tests {
);
}

#[test]
fn test_dict_eq_scalar() {
    // Dictionary with `UInt8` keys and `UInt8` values holding [123, null, 223].
    let mut dict_builder = PrimitiveDictionaryBuilder::new(
        PrimitiveBuilder::<UInt8Type>::new(3),
        PrimitiveBuilder::<UInt8Type>::new(2),
    );
    dict_builder.append(123).unwrap();
    dict_builder.append_null().unwrap();
    dict_builder.append(223).unwrap();
    let dict_array = dict_builder.finish();

    // Only the first slot equals the scalar; the null slot must stay null.
    let actual = eq_dict_scalar(&dict_array, 123).unwrap();
    let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
    assert_eq!(actual, expected);
}

// #[test]
// fn test_dict_eq_utf8_scalar() {
// let a: DictionaryArray<Int8Type> = vec!["a", "b", "c"].into_iter().collect();
// let a_eq = eq_dict_utf8_scalar(&a, "b").unwrap();
// assert_eq!(a_eq, BooleanArray::from(vec![false, true, false]));
// }
// #[test]
// fn test_dict_neq_scalar() {
// let a: DictionaryArray<Int8Type> =
// vec!["hi","hello", "world"].into_iter().collect();
// let a_eq = neq_dict_scalar(&a, "hello").unwrap();
// assert_eq!(a_eq, BooleanArray::from(vec![true, false, true]));
// }

// #[test]
// fn test_dict_lt_scalar() {
// let a: DictionaryArray<Int8Type> =
// vec!["hi","hello", "world"].into_iter().collect();
// let a_eq = lt_dict_scalar(&a, "hi").unwrap();
// assert_eq!(a_eq, BooleanArray::from(vec![false, true, false]));
// }

macro_rules! test_utf8_scalar {
($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
#[test]
Expand Down