Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push SortOptions into DynComparator Allowing Nested Comparisons (#5426) #5792

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 172 additions & 4 deletions arrow-ord/src/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
/// Returns a comparison function that compares two values at two different positions
/// between the two arrays.
///
/// For comparing arrays element-wise, see also the vectorised kernels in [`crate::cmp`].
///
/// If `nulls_first` is true `NULL` values will be considered less than any non-null value,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might help to point people at the faster kernels too

like "see kernels in ord::cmp for fast kernels

/// otherwise they will be considered greater.
///
Expand Down Expand Up @@ -289,10 +291,21 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
///
/// # Postgres-compatible Nested Comparison
///
/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a null yields
/// a NULL, many systems, including postgres, instead apply a total ordering to comparison
/// of nested nulls. That is nulls within nested types are either greater than any value,
/// or less than any value (Spark). This could be implemented as below
/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a NULL yields
/// a NULL, many systems, including postgres, instead apply a total ordering to comparison of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

/// nested nulls. That is nulls within nested types are either greater than any value (postgres),
/// or less than any value (Spark).
///
/// In particular
///
/// ```ignore
/// { a: 1, b: null } == { a: 1, b: null } => true
/// { a: 1, b: null } == { a: 1, b: 1 } => false
/// { a: 1, b: null } == null => null
/// null == null => null
/// ```
///
/// This could be implemented as below
///
/// ```
/// # use arrow_array::{Array, BooleanArray};
Expand Down Expand Up @@ -366,7 +379,9 @@ pub fn make_comparator(
#[cfg(test)]
pub mod tests {
use super::*;
use arrow_array::builder::{Int32Builder, ListBuilder};
use arrow_buffer::{i256, IntervalDayTime, OffsetBuffer};
use arrow_schema::{DataType, Field, Fields};
use half::f16;
use std::sync::Arc;

Expand Down Expand Up @@ -736,4 +751,157 @@ pub mod tests {
test_bytes_impl::<BinaryType>();
test_bytes_impl::<LargeBinaryType>();
}

#[test]
fn test_lists() {
let mut a = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
a.extend([
Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
Some(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(1)]),
]),
Some(vec![]),
]);
let a = a.finish();
let mut b = ListBuilder::new(ListBuilder::new(Int32Builder::new()));
b.extend([
Some(vec![Some(vec![Some(1), Some(2), None]), Some(vec![None])]),
Some(vec![
Some(vec![Some(1), Some(2), None]),
Some(vec![Some(1)]),
]),
Some(vec![
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
Some(vec![Some(1)]),
]),
None,
]);
let b = b.finish();

let opts = SortOptions {
descending: false,
nulls_first: true,
};
let cmp = make_comparator(&a, &b, opts).unwrap();
assert_eq!(cmp(0, 0), Ordering::Equal);
assert_eq!(cmp(0, 1), Ordering::Less);
assert_eq!(cmp(0, 2), Ordering::Less);
assert_eq!(cmp(1, 2), Ordering::Less);
assert_eq!(cmp(1, 3), Ordering::Greater);
assert_eq!(cmp(2, 0), Ordering::Less);

let opts = SortOptions {
descending: true,
nulls_first: true,
};
let cmp = make_comparator(&a, &b, opts).unwrap();
assert_eq!(cmp(0, 0), Ordering::Equal);
assert_eq!(cmp(0, 1), Ordering::Less);
assert_eq!(cmp(0, 2), Ordering::Less);
assert_eq!(cmp(1, 2), Ordering::Greater);
assert_eq!(cmp(1, 3), Ordering::Greater);
assert_eq!(cmp(2, 0), Ordering::Greater);

let opts = SortOptions {
descending: true,
nulls_first: false,
};
let cmp = make_comparator(&a, &b, opts).unwrap();
assert_eq!(cmp(0, 0), Ordering::Equal);
assert_eq!(cmp(0, 1), Ordering::Greater);
assert_eq!(cmp(0, 2), Ordering::Greater);
assert_eq!(cmp(1, 2), Ordering::Greater);
assert_eq!(cmp(1, 3), Ordering::Less);
assert_eq!(cmp(2, 0), Ordering::Greater);

let opts = SortOptions {
descending: false,
nulls_first: false,
};
let cmp = make_comparator(&a, &b, opts).unwrap();
assert_eq!(cmp(0, 0), Ordering::Equal);
assert_eq!(cmp(0, 1), Ordering::Greater);
assert_eq!(cmp(0, 2), Ordering::Greater);
assert_eq!(cmp(1, 2), Ordering::Less);
assert_eq!(cmp(1, 3), Ordering::Less);
assert_eq!(cmp(2, 0), Ordering::Less);
}

#[test]
fn test_struct() {
let fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new_list("b", Field::new("item", DataType::Int32, true), true),
]);

let a = Int32Array::from(vec![Some(1), Some(2), None, None]);
let mut b = ListBuilder::new(Int32Builder::new());
b.extend([Some(vec![Some(1), Some(2)]), Some(vec![None]), None, None]);
let b = b.finish();

let nulls = Some(NullBuffer::from_iter([true, true, true, false]));
let values = vec![Arc::new(a) as _, Arc::new(b) as _];
let s1 = StructArray::new(fields.clone(), values, nulls);

let a = Int32Array::from(vec![None, Some(2), None]);
let mut b = ListBuilder::new(Int32Builder::new());
b.extend([None, None, Some(vec![])]);
let b = b.finish();

let values = vec![Arc::new(a) as _, Arc::new(b) as _];
let s2 = StructArray::new(fields.clone(), values, None);

let opts = SortOptions {
descending: false,
nulls_first: true,
};
let cmp = make_comparator(&s1, &s2, opts).unwrap();
assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)

let opts = SortOptions {
descending: true,
nulls_first: true,
};
let cmp = make_comparator(&s1, &s2, opts).unwrap();
assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
assert_eq!(cmp(0, 0), Ordering::Greater); // (1, [1, 2]) cmp (None, None)
assert_eq!(cmp(1, 1), Ordering::Greater); // (2, [None]) cmp (2, None)
assert_eq!(cmp(2, 2), Ordering::Less); // (None, None) cmp (None, [])
assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, [])
assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
assert_eq!(cmp(3, 0), Ordering::Less); // None cmp (None, None)

let opts = SortOptions {
descending: true,
nulls_first: false,
};
let cmp = make_comparator(&s1, &s2, opts).unwrap();
assert_eq!(cmp(0, 1), Ordering::Greater); // (1, [1, 2]) cmp (2, None)
assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)

let opts = SortOptions {
descending: false,
nulls_first: false,
};
let cmp = make_comparator(&s1, &s2, opts).unwrap();
assert_eq!(cmp(0, 1), Ordering::Less); // (1, [1, 2]) cmp (2, None)
assert_eq!(cmp(0, 0), Ordering::Less); // (1, [1, 2]) cmp (None, None)
assert_eq!(cmp(1, 1), Ordering::Less); // (2, [None]) cmp (2, None)
assert_eq!(cmp(2, 2), Ordering::Greater); // (None, None) cmp (None, [])
assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, [])
assert_eq!(cmp(2, 0), Ordering::Equal); // (None, None) cmp (None, None)
assert_eq!(cmp(3, 0), Ordering::Greater); // None cmp (None, None)
}
}
53 changes: 48 additions & 5 deletions arrow-row/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1302,9 +1302,9 @@ mod tests {
use arrow_array::builder::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::i256;
use arrow_buffer::Buffer;
use arrow_cast::display::array_value_to_string;
use arrow_buffer::{i256, NullBuffer};
use arrow_buffer::{Buffer, OffsetBuffer};
use arrow_cast::display::{ArrayFormatter, FormatOptions};
use arrow_ord::sort::{LexicographicalComparator, SortColumn};

use super::*;
Expand Down Expand Up @@ -2099,9 +2099,35 @@ mod tests {
builder.finish()
}

fn generate_struct(len: usize, valid_percent: f64) -> StructArray {
let mut rng = thread_rng();
let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent)));
let a = generate_primitive_array::<Int32Type>(len, valid_percent);
let b = generate_strings::<i32>(len, valid_percent);
let fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
]);
let values = vec![Arc::new(a) as _, Arc::new(b) as _];
StructArray::new(fields, values, Some(nulls))
}

fn generate_list<F>(len: usize, valid_percent: f64, values: F) -> ListArray
where
F: FnOnce(usize) -> ArrayRef,
{
let mut rng = thread_rng();
let offsets = OffsetBuffer::<i32>::from_lengths((0..len).map(|_| rng.gen_range(0..10)));
let values_len = offsets.last().unwrap().to_usize().unwrap();
let values = values(values_len);
let nulls = NullBuffer::from_iter((0..len).map(|_| rng.gen_bool(valid_percent)));
let field = Arc::new(Field::new("item", values.data_type().clone(), true));
ListArray::new(field, offsets, values, Some(nulls))
}

fn generate_column(len: usize) -> ArrayRef {
let mut rng = thread_rng();
match rng.gen_range(0..10) {
match rng.gen_range(0..14) {
0 => Arc::new(generate_primitive_array::<Int32Type>(len, 0.8)),
1 => Arc::new(generate_primitive_array::<UInt32Type>(len, 0.8)),
2 => Arc::new(generate_primitive_array::<Int64Type>(len, 0.8)),
Expand All @@ -2125,14 +2151,31 @@ mod tests {
0.8,
)),
9 => Arc::new(generate_fixed_size_binary(len, 0.8)),
10 => Arc::new(generate_struct(len, 0.8)),
11 => Arc::new(generate_list(len, 0.8, |values_len| {
Arc::new(generate_primitive_array::<Int64Type>(values_len, 0.8))
})),
12 => Arc::new(generate_list(len, 0.8, |values_len| {
Arc::new(generate_strings::<i32>(values_len, 0.8))
})),
13 => Arc::new(generate_list(len, 0.8, |values_len| {
Arc::new(generate_struct(values_len, 0.8))
})),
_ => unreachable!(),
}
}

fn print_row(cols: &[SortColumn], row: usize) -> String {
let t: Vec<_> = cols
.iter()
.map(|x| array_value_to_string(&x.values, row).unwrap())
.map(|x| match x.values.is_valid(row) {
true => {
let opts = FormatOptions::default().with_null("NULL");
let formatter = ArrayFormatter::try_new(x.values.as_ref(), &opts).unwrap();
formatter.value(row).to_string()
}
false => String::new(),
})
.collect();
t.join(",")
}
Expand Down
Loading