Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arrow-ord: lt and eq for nested list #5408

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 268 additions & 3 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ use arrow_array::{
};
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
use arrow_schema::ArrowError;
use arrow_schema::{ArrowError, DataType};
use arrow_select::take::take;
use std::cmp::Ordering;
use std::ops::Not;

use crate::ord::build_compare;

#[derive(Debug, Copy, Clone)]
enum Op {
Equal,
Expand Down Expand Up @@ -166,6 +169,48 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
compare_op(Op::NotDistinct, lhs, rhs)
}

fn process_nested(
l: &dyn Array,
r: &dyn Array,
op: Op,
l_t: &DataType,
r_t: &DataType,
len: usize,
) -> Result<Option<BooleanArray>, ArrowError> {
use arrow_schema::DataType::*;
if let (List(_), List(_)) = (l_t, r_t) {
fn process_ordering(
l: &dyn Array,
r: &dyn Array,
target_ord: Ordering,
len: usize,
) -> Result<BooleanArray, ArrowError> {
let cmp = build_compare(l, r)?;
let mut values = BooleanArray::builder(len);
for i in 0..len {
let ord = cmp(i, i);
values.append_value(ord == target_ord);
}
Ok(values.finish())
}

// Process nested data types
match op {
Op::Less => Ok(Some(process_ordering(l, r, Ordering::Less, len)?)),
Op::Equal => Ok(Some(process_ordering(l, r, Ordering::Equal, len)?)),
_ => Err(ArrowError::NotYetImplemented(format!(
"Comparison for {op} is NYI"
))),
}
} else if l_t.is_nested() {
Err(ArrowError::NotYetImplemented(format!(
"Comparison for {l_t} is NYI"
)))
} else {
Ok(None)
}
}

/// Perform `op` on the provided `Datum`
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
Expand Down Expand Up @@ -198,12 +243,16 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
let r_t = r.data_type();

if l_t != r_t || l_t.is_nested() {
if l_t != r_t {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now allows hitting unreachable in the below code block

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I return an error in process_nested, so the nested type that is NYI will not go down there

return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
}

if let Some(values) = process_nested(l, r, op, l_t, r_t, len)? {
return Ok(values);
}

// Defer computation as may not be necessary
let values = || -> BooleanBuffer {
let d = downcast_primitive_array! {
Expand Down Expand Up @@ -544,7 +593,11 @@ impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
mod tests {
use std::sync::Arc;

use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray};
use arrow_array::{
types::Int32Type, ArrayRef, DictionaryArray, Int32Array, ListArray, Scalar, StringArray,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;

use super::*;

Expand Down Expand Up @@ -702,4 +755,216 @@ mod tests {

neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap();
}

Copy link
Contributor

@tustvold tustvold Feb 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to see some tests of

  • Scalar arguments
  • Nulls masking non-empty slices
  • DictionaryArray of ListArray (returning an error would be perfectly valid for this)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

#[test]
fn test_list_lt() {
let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true, false, false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false, false, false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5), Some(7)]),
]);
let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false, false, true]));
}

fn array_into_list_array(arr: ArrayRef) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
offsets,
arr,
None,
)
}

#[test]
fn test_nested_list_lt() {
let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5), Some(7)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = lt(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true]));
}

#[test]
fn test_list_eq() {
let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false, true, false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true, true, false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5), Some(7)]),
]);
let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true, true, false]));
}

#[test]
fn test_nested_list_eq() {
let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5), Some(7)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![false]));

let l1 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l1 = array_into_list_array(Arc::new(l1));
let l2 = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
None,
Some(vec![Some(3), Some(4), Some(5)]),
]);
let l2 = array_into_list_array(Arc::new(l2));

let res = eq(&l1, &l2).unwrap();
assert_eq!(res, BooleanArray::from(vec![true]));
}
}
29 changes: 29 additions & 0 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,32 @@ use arrow_buffer::ArrowNativeType;
use arrow_schema::ArrowError;
use std::cmp::Ordering;

pub fn list_ordering(l: ArrayRef, r: ArrayRef) -> Ordering {
let l_len = l.len();
let r_len = r.len();
let min_len = std::cmp::min(l_len, r_len);

let cmp = build_compare(&l, &r).unwrap();

for i in 0..min_len {
let ord = cmp(i, i);
if ord != Ordering::Equal {
return ord;
}
}

l_len.cmp(&r_len)
}

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;

fn compare_list(l: &dyn Array, r: &dyn Array) -> DynComparator {
let l = l.as_list::<i32>().to_owned();
let r = r.as_list::<i32>().to_owned();
Box::new(move |i, j| list_ordering(l.value(i), r.value(j)))
}

fn compare_primitive<T: ArrowPrimitiveType>(left: &dyn Array, right: &dyn Array) -> DynComparator
where
T::Native: ArrowNativeTypeOp,
Expand Down Expand Up @@ -121,6 +144,12 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
_ => unreachable!()
}
},
(List(_), List(_)) => {
Ok(compare_list(left, right))
// let l = left.as_list::<i32>().to_owned();
// let r = right.as_list::<i32>().to_owned();
// Ok(Box::new(move |i, j| compare_list(l.value(i), r.value(j))))
},
(lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
true => format!("The data type type {lhs:?} has no natural order"),
false => "Can't compare arrays of different types".to_string(),
Expand Down
Loading