diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 3b65f339b804..f98e15d549af 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -27,8 +27,9 @@ use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuff use crate::compute::binary_boolean_kernel; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ - ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, ArrowPrimitiveType, DataType, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; @@ -200,6 +201,66 @@ macro_rules! compare_op_scalar_primitive { }}; } +macro_rules! compare_dict_op_scalar { + ($left:expr, $T:ident, $right:expr, $op:expr) => {{ + let null_bit_buffer = $left + .data() + .null_buffer() + .map(|b| b.bit_slice($left.offset(), $left.len())); + + let values = $left + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + + // Safety: + // `i < $left.len()` + // let comparison: Vec = (0..array.len()) + // .map(|i| unsafe { $op(array.value_unchecked(i), $right) }) + // .collect(); + + let dict_comparison = eq_scalar(values, $right).unwrap(); + + let result: Vec = (0..$left.keys().len()) + .map(|key| { + let index = $left.keys().value(key); + dict_comparison.value( + index + .to_usize() + .expect(format!("Failed at idx {:?}", index).as_str()), + ) + }) + .collect(); + + // let result: Vec = (0..$left.keys().len()) + // .map(|key| { + // let index = $left.keys().value(key); + // comparison[index + // .to_usize() + // .expect(format!("Failed at idx {:?}", index).as_str())] + // }) + // .collect(); + + // same as $left.len() + let buffer = + unsafe { MutableBuffer::from_trusted_len_iter_bool(result.into_iter()) }; + + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + $left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) + }}; +} + /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified /// comparison function. pub fn no_simd_compare_op( @@ -1200,7 +1261,40 @@ where return compare_op_scalar!(left, right, |a, b| a == b); } -/// Perform `left != right` operation on two [`PrimitiveArray`]s. +/// Perform `left == right` operation on a [`PrimitiveArray`] and a numeric scalar value. +pub fn eq_dict_scalar( + left: &DictionaryArray, + right: T::Native, +) -> Result +where + T: ArrowNumericType, + K: ArrowNumericType, +{ + #[cfg(not(feature = "simd"))] + println!("{}", std::any::type_name::()); + let dict_comparison = match left.values().data_type() { + DataType::Int8 => eq_scalar(as_primitive_array::(left.values()), right), + _ => Err(ArrowError::ComputeError( + "Dictionary did not store values of type T".to_string(), + )), + }?; + + assert_eq!(dict_comparison.len(), left.values().len()); + + let result: BooleanArray = left + .keys() + .iter() + .map(|key| { + key.map(|key| unsafe { + let key = key.to_usize().expect("Dictionary index not usize"); + dict_comparison.value_unchecked(key) + }) + }) + .collect(); + + Ok(result) +} + pub fn neq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where T: ArrowNumericType, @@ -2032,6 +2126,44 @@ mod tests { ); } + #[test] + fn test_dict_eq_scalar() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(223).unwrap(); + let array = builder.finish(); + let a_eq = eq_dict_scalar(&array, 123).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + + // #[test] + // fn test_dict_eq_utf8_scalar() { + // let a: DictionaryArray = vec!["a", "b", "c"].into_iter().collect(); + // let a_eq = eq_dict_utf8_scalar(&a, "b").unwrap(); + // assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); + // } + // #[test] + // fn test_dict_neq_scalar() { + // let a: DictionaryArray = + // vec!["hi","hello", "world"].into_iter().collect(); + // let a_eq = neq_dict_scalar(&a, "hello").unwrap(); + // assert_eq!(a_eq, BooleanArray::from(vec![true, false, true])); + // } + + // #[test] + // fn test_dict_lt_scalar() { + // let a: DictionaryArray = + // vec!["hi","hello", "world"].into_iter().collect(); + // let a_eq = lt_dict_scalar(&a, "hi").unwrap(); + // assert_eq!(a_eq, BooleanArray::from(vec![false, true, false])); + // } + macro_rules! test_utf8_scalar { ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { #[test]