Skip to content

Commit

Permalink
Implement DictionaryArray support in neq_dyn, lt_dyn, lt_eq_dyn, gt_dyn, gt_eq_dyn (#1326)
Browse files Browse the repository at this point in the history

* Implement DictionaryArray support in neq_dyn, lt_dyn, lt_eq_dyn, gt_dyn, gt_eq_dyn

* Fix clippy

* Fix format

* Add test

* For review comment and suggestion

* Allow reasonable boolean comparisons

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
viirya and alamb authored Mar 1, 2022
1 parent e89777f commit 483a502
Showing 1 changed file with 143 additions and 31 deletions.
174 changes: 143 additions & 31 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2032,10 +2032,10 @@ macro_rules! typed_compares {

/// Applies $OP (or $OP_BOOL when the dictionary values are boolean) to $LEFT and $RIGHT,
/// which are two dictionaries that have the same key type $KT
macro_rules! typed_dict_cmp {
($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{
match ($LEFT.value_type(), $RIGHT.value_type()) {
(DataType::Boolean, DataType::Boolean) => {
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP)
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL)
}
(DataType::Int8, DataType::Int8) => {
cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)
Expand Down Expand Up @@ -2141,49 +2141,49 @@ macro_rules! typed_dict_cmp {

macro_rules! typed_dict_compares {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> {
match (left_key_type.as_ref(), right_key_type.as_ref()) {
(DataType::Int8, DataType::Int8) => {
let left = as_dictionary_array::<Int8Type>($LEFT);
let right = as_dictionary_array::<Int8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type)
}
(DataType::Int16, DataType::Int16) => {
let left = as_dictionary_array::<Int16Type>($LEFT);
let right = as_dictionary_array::<Int16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type)
}
(DataType::Int32, DataType::Int32) => {
let left = as_dictionary_array::<Int32Type>($LEFT);
let right = as_dictionary_array::<Int32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type)
}
(DataType::Int64, DataType::Int64) => {
let left = as_dictionary_array::<Int64Type>($LEFT);
let right = as_dictionary_array::<Int64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type)
}
(DataType::UInt8, DataType::UInt8) => {
let left = as_dictionary_array::<UInt8Type>($LEFT);
let right = as_dictionary_array::<UInt8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type)
}
(DataType::UInt16, DataType::UInt16) => {
let left = as_dictionary_array::<UInt16Type>($LEFT);
let right = as_dictionary_array::<UInt16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type)
}
(DataType::UInt32, DataType::UInt32) => {
let left = as_dictionary_array::<UInt32Type>($LEFT);
let right = as_dictionary_array::<UInt32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type)
}
(DataType::UInt64, DataType::UInt64) => {
let left = as_dictionary_array::<UInt64Type>($LEFT);
let right = as_dictionary_array::<UInt64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type)
}
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing dictionary arrays of type {} is not yet implemented",
Expand Down Expand Up @@ -2317,7 +2317,7 @@ where
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// NOTE(review): the two calls below look like the before/after pair of the diff
// (the three-argument call replaced by the four-argument one) — confirm that only
// the four-argument call exists in the committed file.
typed_dict_compares!(left, right, |a, b| a == b)
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`; the
// second one is the comparison handed to `cmp_dict_bool` for boolean values.
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b)
}
_ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary),
}
}
Expand All @@ -2340,7 +2340,12 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result);
/// ```
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// NOTE(review): this bare call looks like the pre-change body kept by the diff view; the
// `match` that follows appears to be its replacement — confirm against the committed file.
typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary)
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`.
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b)
}
_ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary),
}
}

/// Perform `left < right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2358,8 +2363,14 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// let result = lt_dyn(&array1, &array2).unwrap();
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
// The dictionary path below also applies `a < b` to boolean values (second closure,
// `$OP_BOOL`); presumably that ordering comparison on `bool` is what trips clippy's
// `bool_comparison` lint, hence the allow — confirm with `cargo clippy`.
#[allow(clippy::bool_comparison)]
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// NOTE(review): this bare call looks like the pre-change body kept by the diff view; the
// `match` that follows appears to be its replacement — confirm against the committed file.
typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary)
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`.
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b)
}
_ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary),
}
}

/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2378,7 +2389,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result);
/// ```
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// NOTE(review): this bare call looks like the pre-change body kept by the diff view; the
// `match` that follows appears to be its replacement — confirm against the committed file.
typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary)
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`.
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b)
}
_ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary),
}
}

/// Perform `left > right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2395,8 +2411,14 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// let result = gt_dyn(&array1, &array2).unwrap();
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
// The dictionary path below also applies `a > b` to boolean values (second closure,
// `$OP_BOOL`); presumably that ordering comparison on `bool` is what trips clippy's
// `bool_comparison` lint, hence the allow — confirm with `cargo clippy`.
#[allow(clippy::bool_comparison)]
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// NOTE(review): this bare call looks like the pre-change body kept by the diff view; the
// `match` that follows appears to be its replacement — confirm against the committed file.
typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary)
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`.
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b)
}
_ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary),
}
}

/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2414,7 +2436,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result);
/// ```
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
// NOTE(review): this bare call looks like the pre-change body kept by the diff view; the
// `match` that follows appears to be its replacement — confirm against the committed file.
typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary)
// Dispatch on the left operand's data type: a dictionary-encoded array goes through
// `typed_dict_compares!`, every other type through `typed_compares!`.
match left.data_type() {
DataType::Dictionary(_, _) => {
// The two closures bind to `$OP` and `$OP_BOOL` of `typed_dict_compares!`.
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b)
}
_ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary),
}
}

/// Perform `left == right` operation on two [`PrimitiveArray`]s.
Expand Down Expand Up @@ -4663,7 +4690,7 @@ mod tests {
}

#[test]
fn test_eq_dyn_dictionary_i8_array() {
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

Expand All @@ -4673,12 +4700,17 @@ mod tests {
let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);
}

#[test]
fn test_eq_dyn_dictionary_u64_array() {
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]);
Expand All @@ -4689,15 +4721,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_eq_dyn_dictionary_utf8_array() {
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
let test1 = vec!["a", "a", "b", "c"];
let test2 = vec!["a", "b", "b", "c"];

Expand All @@ -4711,15 +4745,20 @@ mod tests {
.collect();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![Some(true), None, None, Some(true)])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![Some(false), None, None, Some(false)])
);
}

#[test]
fn test_eq_dyn_dictionary_binary_array() {
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
.map(|b| Some(b.as_bytes()))
Expand All @@ -4733,15 +4772,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_eq_dyn_dictionary_interval_array() {
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4752,12 +4793,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_date_array() {
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4768,12 +4814,17 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_bool_array() {
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]);
Expand All @@ -4784,10 +4835,71 @@ mod tests {
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
    // One value array backs both dictionaries; only the key arrays differ.
    let dict_values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

    // Keys [3, 4, 4] select [13, 14, 14]; keys [4, 3, 4] select [14, 13, 14].
    let lhs_keys = Int8Array::from_iter_values([3_i8, 4, 4]);
    let rhs_keys = Int8Array::from_iter_values([4_i8, 3, 4]);
    let lhs = DictionaryArray::try_new(&lhs_keys, &dict_values).unwrap();
    let rhs = DictionaryArray::try_new(&rhs_keys, &dict_values).unwrap();

    // Strictly less: only the first pair (13 vs 14) qualifies.
    let less = lt_dyn(&lhs, &rhs).unwrap();
    assert_eq!(less, BooleanArray::from(vec![true, false, false]));

    // Less-or-equal additionally accepts the tied last pair (14 vs 14).
    let less_or_equal = lt_eq_dyn(&lhs, &rhs).unwrap();
    assert_eq!(less_or_equal, BooleanArray::from(vec![true, false, true]));

    // Strictly greater: only the middle pair (14 vs 13) qualifies.
    let greater = gt_dyn(&lhs, &rhs).unwrap();
    assert_eq!(greater, BooleanArray::from(vec![false, true, false]));

    // Greater-or-equal additionally accepts the tied last pair.
    let greater_or_equal = gt_eq_dyn(&lhs, &rhs).unwrap();
    assert_eq!(
        greater_or_equal,
        BooleanArray::from(vec![false, true, true])
    );
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
    // Index 0 is `true` and index 1 is `false` in the shared value array.
    let dict_values = BooleanArray::from(vec![true, false]);

    // Keys [1, 1, 0] select [false, false, true]; keys [0, 1, 1] select [true, false, false].
    let lhs_keys = UInt64Array::from_iter_values([1_u64, 1, 0]);
    let rhs_keys = UInt64Array::from_iter_values([0_u64, 1, 1]);
    let lhs: DictionaryArray<UInt64Type> =
        DictionaryArray::try_new(&lhs_keys, &dict_values).unwrap();
    let rhs: DictionaryArray<UInt64Type> =
        DictionaryArray::try_new(&rhs_keys, &dict_values).unwrap();

    // With `false < true`, only the first pair (false vs true) is strictly less.
    let less = lt_dyn(&lhs, &rhs).unwrap();
    assert_eq!(less, BooleanArray::from(vec![true, false, false]));

    // The tied middle pair (false vs false) also satisfies `<=`.
    let less_or_equal = lt_eq_dyn(&lhs, &rhs).unwrap();
    assert_eq!(less_or_equal, BooleanArray::from(vec![true, true, false]));

    // Only the last pair (true vs false) is strictly greater.
    let greater = gt_dyn(&lhs, &rhs).unwrap();
    assert_eq!(greater, BooleanArray::from(vec![false, false, true]));

    // The tied middle pair also satisfies `>=`.
    let greater_or_equal = gt_eq_dyn(&lhs, &rhs).unwrap();
    assert_eq!(
        greater_or_equal,
        BooleanArray::from(vec![false, true, true])
    );
}
}

0 comments on commit 483a502

Please sign in to comment.