Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DictionaryArray support in neq_dyn, lt_dyn, lt_eq_dyn, gt_dyn, gt_eq_dyn #1326

Merged
merged 7 commits into from
Mar 1, 2022
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 156 additions & 24 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2032,10 +2032,10 @@ macro_rules! typed_compares {

/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT
macro_rules! typed_dict_cmp {
($LEFT: expr, $RIGHT: expr, $OP: expr, $KT: tt) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 nice readability improvement

match ($LEFT.value_type(), $RIGHT.value_type()) {
(DataType::Boolean, DataType::Boolean) => {
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP)
cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL)
}
(DataType::Int8, DataType::Int8) => {
cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP)
Expand Down Expand Up @@ -2141,49 +2141,49 @@ macro_rules! typed_dict_cmp {

macro_rules! typed_dict_compares {
// Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray`
($LEFT: expr, $RIGHT: expr, $OP: expr) => {{
($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{
match ($LEFT.data_type(), $RIGHT.data_type()) {
(DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> {
match (left_key_type.as_ref(), right_key_type.as_ref()) {
(DataType::Int8, DataType::Int8) => {
let left = as_dictionary_array::<Int8Type>($LEFT);
let right = as_dictionary_array::<Int8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type)
}
(DataType::Int16, DataType::Int16) => {
let left = as_dictionary_array::<Int16Type>($LEFT);
let right = as_dictionary_array::<Int16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type)
}
(DataType::Int32, DataType::Int32) => {
let left = as_dictionary_array::<Int32Type>($LEFT);
let right = as_dictionary_array::<Int32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type)
}
(DataType::Int64, DataType::Int64) => {
let left = as_dictionary_array::<Int64Type>($LEFT);
let right = as_dictionary_array::<Int64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, Int64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type)
}
(DataType::UInt8, DataType::UInt8) => {
let left = as_dictionary_array::<UInt8Type>($LEFT);
let right = as_dictionary_array::<UInt8Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt8Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type)
}
(DataType::UInt16, DataType::UInt16) => {
let left = as_dictionary_array::<UInt16Type>($LEFT);
let right = as_dictionary_array::<UInt16Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt16Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type)
}
(DataType::UInt32, DataType::UInt32) => {
let left = as_dictionary_array::<UInt32Type>($LEFT);
let right = as_dictionary_array::<UInt32Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt32Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type)
}
(DataType::UInt64, DataType::UInt64) => {
let left = as_dictionary_array::<UInt64Type>($LEFT);
let right = as_dictionary_array::<UInt64Type>($RIGHT);
typed_dict_cmp!(left, right, $OP, UInt64Type)
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type)
}
(t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
"Comparing dictionary arrays of type {} is not yet implemented",
Expand Down Expand Up @@ -2318,7 +2318,7 @@ where
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a == b)
typed_dict_compares!(left, right, |a, b| a == b, |a, b| !(a ^ b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this change -- I think `a == b` is easier to understand, and I would expect LLVM to generate optimized code for whatever is being compared.

If this is clippy being silly about comparing booleans, perhaps we can just ignore the lint.

Suggested change
typed_dict_compares!(left, right, |a, b| a == b, |a, b| !(a ^ b))
typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, okay. I wrote it the way you suggest at first, but changed it basically to make clippy happy. 😄
If we can ignore that lint, then I can change it back.

Copy link
Contributor

@alamb alamb Feb 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can ignore it. Clippy is probably somewhat confused when the parameters are boolean.

}
_ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary),
}
Expand All @@ -2341,7 +2341,12 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result);
/// ```
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary)
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a != b, |a, b| (a ^ b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typed_dict_compares!(left, right, |a, b| a != b, |a, b| (a ^ b))
typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b)

}
_ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary),
}
}

/// Perform `left < right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2360,7 +2365,12 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary)
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a < b, |a, b| (!a) & b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typed_dict_compares!(left, right, |a, b| a < b, |a, b| (!a) & b)
typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b)

}
_ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary),
}
}

/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2379,7 +2389,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result);
/// ```
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary)
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| !(a & (!b)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| !(a & (!b)))
typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b)

}
_ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary),
}
}

/// Perform `left > right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2397,7 +2412,12 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result);
/// ```
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary)
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a & (!b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a & (!b))
typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b)

}
_ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary),
}
}

/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
Expand All @@ -2415,7 +2435,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result);
/// ```
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary)
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| !((!a) & b))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| !((!a) & b))
typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b)

}
_ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary),
}
}

/// Perform `left == right` operation on two [`PrimitiveArray`]s.
Expand Down Expand Up @@ -4664,7 +4689,7 @@ mod tests {
}

#[test]
fn test_eq_dyn_dictionary_i8_array() {
fn test_eq_dyn_neq_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

Expand All @@ -4676,10 +4701,17 @@ mod tests {
let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);
}

#[test]
fn test_eq_dyn_dictionary_u64_array() {
fn test_eq_dyn_neq_dyn_dictionary_u64_array() {
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]);
Expand All @@ -4695,10 +4727,14 @@ mod tests {
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_eq_dyn_dictionary_utf8_array() {
fn test_eq_dyn_neq_dyn_dictionary_utf8_array() {
let test1 = vec!["a", "a", "b", "c"];
let test2 = vec!["a", "b", "b", "c"];

Expand All @@ -4717,10 +4753,17 @@ mod tests {
result.unwrap(),
BooleanArray::from(vec![Some(true), None, None, Some(true)])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![Some(false), None, None, Some(false)])
);
}

#[test]
fn test_eq_dyn_dictionary_binary_array() {
fn test_eq_dyn_neq_dyn_dictionary_binary_array() {
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
.map(|b| Some(b.as_bytes()))
Expand All @@ -4739,10 +4782,14 @@ mod tests {
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_eq_dyn_dictionary_interval_array() {
fn test_eq_dyn_neq_dyn_dictionary_interval_array() {
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4755,10 +4802,17 @@ mod tests {
let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_date_array() {
fn test_eq_dyn_neq_dyn_dictionary_date_array() {
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
Expand All @@ -4771,10 +4825,17 @@ mod tests {
let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
}

#[test]
fn test_eq_dyn_dictionary_bool_array() {
fn test_eq_dyn_neq_dyn_dictionary_bool_array() {
let values = BooleanArray::from(vec![true, false]);

let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]);
Expand All @@ -4790,5 +4851,76 @@ mod tests {
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = neq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a style thing, I think it is OK to just `.unwrap()` the result -- if there is a problem, it will panic one line later, but I think the source of the problem would still be quite clear.

Suggested change
assert!(result.is_ok());

assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_i8_array() {
// Construct a value array
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = Int8Array::from_iter_values([3_i8, 4, 4]);
let keys2 = Int8Array::from_iter_values([4_i8, 3, 4]);
let dict_array1 = DictionaryArray::try_new(&keys1, &values).unwrap();
let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap();

let result = lt_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![true, false, false])
);
viirya marked this conversation as resolved.
Show resolved Hide resolved

let result = lt_eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true]));

let result = gt_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
BooleanArray::from(vec![false, true, false])
);

let result = gt_eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true]));
}

#[test]
fn test_lt_dyn_gt_dyn_dictionary_bool_array() {
    // Shared dictionary values: key 0 maps to `true`, key 1 maps to `false`.
    let dict_values = BooleanArray::from(vec![true, false]);

    // Left keys decode to  [false, false, true].
    // Right keys decode to [true,  false, false].
    let left_keys = UInt64Array::from_iter_values([1_u64, 1, 0]);
    let right_keys = UInt64Array::from_iter_values([0_u64, 1, 1]);
    let left =
        DictionaryArray::<UInt64Type>::try_new(&left_keys, &dict_values).unwrap();
    let right =
        DictionaryArray::<UInt64Type>::try_new(&right_keys, &dict_values).unwrap();

    // The expected vectors below follow Rust's `bool` ordering, `false < true`.

    // left < right
    let less = lt_dyn(&left, &right);
    assert!(less.is_ok());
    let expected_less = BooleanArray::from(vec![true, false, false]);
    assert_eq!(less.unwrap(), expected_less);

    // left <= right
    let less_or_equal = lt_eq_dyn(&left, &right);
    assert!(less_or_equal.is_ok());
    let expected_less_or_equal = BooleanArray::from(vec![true, true, false]);
    assert_eq!(less_or_equal.unwrap(), expected_less_or_equal);

    // left > right
    let greater = gt_dyn(&left, &right);
    assert!(greater.is_ok());
    let expected_greater = BooleanArray::from(vec![false, false, true]);
    assert_eq!(greater.unwrap(), expected_greater);

    // left >= right
    let greater_or_equal = gt_eq_dyn(&left, &right);
    assert!(greater_or_equal.is_ok());
    let expected_greater_or_equal = BooleanArray::from(vec![false, true, true]);
    assert_eq!(greater_or_equal.unwrap(), expected_greater_or_equal);
}
}