From 598c2f15e12a4d9435ceba7b9d8774f35712dd0b Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 21 May 2024 17:22:28 +0100 Subject: [PATCH] Push SortOptions into DynComparator (#5426) --- arrow-ord/src/ord.rs | 382 ++++++++++++++++++++++++++++++++++-------- arrow-ord/src/sort.rs | 114 ++----------- 2 files changed, 323 insertions(+), 173 deletions(-) diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs index 8f21cd7c498d..c660d4778f8e 100644 --- a/arrow-ord/src/ord.rs +++ b/arrow-ord/src/ord.rs @@ -20,36 +20,117 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::ArrowNativeType; -use arrow_schema::ArrowError; +use arrow_buffer::{ArrowNativeType, NullBuffer}; +use arrow_schema::{ArrowError, SortOptions}; use std::cmp::Ordering; /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Ordering + Send + Sync>; -fn compare_primitive(left: &dyn Array, right: &dyn Array) -> DynComparator +/// If parent sort order is descending we need to invert the value of nulls_first so that +/// when the parent is sorted based on the produced ranks, nulls are still ordered correctly +fn child_opts(opts: SortOptions) -> SortOptions { + SortOptions { + descending: false, + nulls_first: opts.nulls_first != opts.descending, + } +} + +fn compare(l: &A, r: &A, opts: SortOptions, cmp: F) -> DynComparator where - T::Native: ArrowNativeTypeOp, + A: Array + Clone, + F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, { - let left = left.as_primitive::().clone(); - let right = right.as_primitive::().clone(); - Box::new(move |i, j| left.value(i).compare(right.value(j))) + let l = l.logical_nulls().filter(|x| x.null_count() > 0); + let r = r.logical_nulls().filter(|x| x.null_count() > 0); + match (opts.nulls_first, opts.descending) { + (true, true) => compare_impl::(l, r, cmp), + (true, false) => compare_impl::(l, r, cmp), + (false, true) => compare_impl::(l, r, cmp), + (false, false) => compare_impl::(l, r, cmp), + } } -fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left: BooleanArray = left.as_boolean().clone(); - let right: BooleanArray = right.as_boolean().clone(); +fn compare_impl( + l: Option, + r: Option, + cmp: F, +) -> DynComparator +where + F: Fn(usize, usize) -> Ordering + Send + Sync + 'static, +{ + let cmp = move |i, j| match DESCENDING { + true => cmp(i, j).reverse(), + false => cmp(i, j), + }; + + let (left_null, right_null) = match NULLS_FIRST { + true => (Ordering::Less, Ordering::Greater), + false => (Ordering::Greater, Ordering::Less), + }; + + match (l, r) { + (None, None) => Box::new(cmp), + (Some(l), None) => Box::new(move |i, j| match l.is_null(i) { + true => left_null, + false => cmp(i, j), + }), + (None, Some(r)) => Box::new(move |i, j| match r.is_null(j) { + true => right_null, + false => cmp(i, j), + }), + (Some(l), Some(r)) => Box::new(move |i, j| match (l.is_null(i), r.is_null(j)) { + (true, true) => Ordering::Equal, + (true, false) => left_null, + (false, true) => right_null, + (false, false) => cmp(i, j), + }), + } +} + +fn compare_primitive( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> DynComparator +where + T::Native: ArrowNativeTypeOp, +{ + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let l_values = left.values().clone(); + let r_values = right.values().clone(); - Box::new(move |i, j| left.value(i).cmp(&right.value(j))) + compare(&left, &right, opts, move |i, j| { + l_values[i].compare(r_values[j]) + }) } -fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left.as_bytes::().clone(); - let right = right.as_bytes::().clone(); +fn compare_boolean(left: &dyn Array, right: &dyn Array, opts: SortOptions) -> DynComparator { + let left = left.as_boolean(); + let right = right.as_boolean(); + + let l_values = left.values().clone(); + let r_values = right.values().clone(); + + compare(left, right, opts, move |i, j| { + l_values.value(i).cmp(&r_values.value(j)) + }) +} - Box::new(move |i, j| { - let l: &[u8] = left.value(i).as_ref(); - let r: &[u8] = right.value(j).as_ref(); +fn compare_bytes( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> DynComparator { + let left = left.as_bytes::(); + let right = right.as_bytes::(); + + let l = left.clone(); + let r = right.clone(); + compare(left, right, opts, move |i, j| { + let l: &[u8] = l.value(i).as_ref(); + let r: &[u8] = r.value(j).as_ref(); l.cmp(r) }) } @@ -57,67 +138,220 @@ fn compare_bytes(left: &dyn Array, right: &dyn Array) -> DynCo fn compare_dict( left: &dyn Array, right: &dyn Array, + opts: SortOptions, ) -> Result { let left = left.as_dictionary::(); let right = right.as_dictionary::(); - let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?; - let left_keys = left.keys().clone(); - let right_keys = right.keys().clone(); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), opts)?; + let left_keys = left.keys().values().clone(); + let right_keys = right.keys().values().clone(); - // TODO: Handle value nulls (#2687) - Ok(Box::new(move |i, j| { - let l = left_keys.value(i).as_usize(); - let r = right_keys.value(j).as_usize(); + let f = compare(left, right, opts, move |i, j| { + let l = left_keys[i].as_usize(); + let r = right_keys[j].as_usize(); cmp(l, r) - })) + }); + Ok(f) +} + +fn compare_list( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_list::(); + let right = right.as_list::(); + + let c_opts = child_opts(opts); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?; + + let l_o = left.offsets().clone(); + let r_o = right.offsets().clone(); + let f = compare(left, right, opts, move |i, j| { + let l_end = l_o[i + 1].as_usize(); + let l_start = l_o[i].as_usize(); + + let r_end = r_o[j + 1].as_usize(); + let r_start = r_o[j].as_usize(); + + for (i, j) in (l_start..l_end).zip(r_start..r_end) { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + (l_end - l_start).cmp(&(r_end - r_start)) + }); + Ok(f) +} + +fn compare_fixed_list( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_fixed_size_list(); + let right = right.as_fixed_size_list(); + + let c_opts = child_opts(opts); + let cmp = make_comparator(left.values().as_ref(), right.values().as_ref(), c_opts)?; + + let l_size = left.value_length().to_usize().unwrap(); + let r_size = right.value_length().to_usize().unwrap(); + let size_cmp = l_size.cmp(&r_size); + + let f = compare(left, right, opts, move |i, j| { + let l_start = i * l_size; + let l_end = l_start + l_size; + let r_start = j * r_size; + let r_end = r_start + r_size; + for (i, j) in (l_start..l_end).zip(r_start..r_end) { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + size_cmp + }); + Ok(f) } -/// returns a comparison function that compares two values at two different positions +fn compare_struct( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { + let left = left.as_struct(); + let right = right.as_struct(); + + if left.columns().len() != right.columns().len() { + return Err(ArrowError::InvalidArgumentError( + "Cannot compare StructArray with different number of columns".to_string(), + )); + } + + let c_opts = child_opts(opts); + let columns = left.columns().iter().zip(right.columns()); + let comparators = columns + .map(|(l, r)| make_comparator(l, r, c_opts)) + .collect::, _>>()?; + + let f = compare(left, right, opts, move |i, j| { + for cmp in &comparators { + match cmp(i, j) { + Ordering::Equal => continue, + r => return r, + } + } + Ordering::Equal + }); + Ok(f) +} + +#[deprecated(note = "Use make_comparator")] +#[doc(hidden)] +pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + make_comparator(left, right, SortOptions::default()) +} + +/// Returns a comparison function that compares two values at two different positions /// between the two arrays. -/// The arrays' types must be equal. -/// # Example -/// ``` -/// use arrow_array::Int32Array; -/// use arrow_ord::ord::build_compare; /// +/// If `nulls_first` is true `NULL` values will be considered less than any non-null value, +/// otherwise they will be considered greater. +/// +/// # Basic Usage +/// +/// ``` +/// # use std::cmp::Ordering; +/// # use arrow_array::Int32Array; +/// # use arrow_ord::ord::make_comparator; +/// # use arrow_schema::SortOptions; +/// # /// let array1 = Int32Array::from(vec![1, 2]); /// let array2 = Int32Array::from(vec![3, 4]); /// -/// let cmp = build_compare(&array1, &array2).unwrap(); -/// +/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) -/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1)); +/// assert_eq!(cmp(0, 1), Ordering::Less); +/// +/// let array1 = Int32Array::from(vec![Some(1), None]); +/// let array2 = Int32Array::from(vec![None, Some(2)]); +/// let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); +/// +/// assert_eq!(cmp(0, 1), Ordering::Less); // Some(1) vs Some(2) +/// assert_eq!(cmp(1, 1), Ordering::Less); // None vs Some(2) +/// assert_eq!(cmp(1, 0), Ordering::Equal); // None vs None +/// assert_eq!(cmp(0, 0), Ordering::Greater); // Some(1) vs None /// ``` -// This is a factory of comparisons. -// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime. -pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { +/// +/// # Postgres-compatible Nested Comparison +/// +/// Whilst SQL prescribes ternary logic for nulls, that is comparing a value against a null yields +/// a NULL, many systems, including postgres, instead apply a total ordering to comparison +/// of nested nulls. That is nulls within nested types are either greater than any value, +/// or less than any value (Spark). This could be implemented as below +/// +/// ``` +/// # use arrow_array::{Array, BooleanArray}; +/// # use arrow_buffer::NullBuffer; +/// # use arrow_ord::cmp; +/// # use arrow_ord::ord::make_comparator; +/// # use arrow_schema::{ArrowError, SortOptions}; +/// fn eq(a: &dyn Array, b: &dyn Array) -> Result { +/// if !a.data_type().is_nested() { +/// return cmp::eq(a, b); // Use faster vectorised kernel +/// } +/// +/// let cmp = make_comparator(a, b, SortOptions::default())?; +/// let len = a.len().min(b.len()); +/// let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); +/// let nulls = NullBuffer::union(a.nulls(), b.nulls()); +/// Ok(BooleanArray::new(values, nulls)) +/// } +/// ```` +pub fn make_comparator( + left: &dyn Array, + right: &dyn Array, + opts: SortOptions, +) -> Result { use arrow_schema::DataType::*; + macro_rules! primitive_helper { - ($t:ty, $left:expr, $right:expr) => { - Ok(compare_primitive::<$t>($left, $right)) + ($t:ty, $left:expr, $right:expr, $nulls_first:expr) => { + Ok(compare_primitive::<$t>($left, $right, $nulls_first)) }; } downcast_primitive! { - left.data_type(), right.data_type() => (primitive_helper, left, right), - (Boolean, Boolean) => Ok(compare_boolean(left, right)), - (Utf8, Utf8) => Ok(compare_bytes::(left, right)), - (LargeUtf8, LargeUtf8) => Ok(compare_bytes::(left, right)), - (Binary, Binary) => Ok(compare_bytes::(left, right)), - (LargeBinary, LargeBinary) => Ok(compare_bytes::(left, right)), + left.data_type(), right.data_type() => (primitive_helper, left, right, opts), + (Boolean, Boolean) => Ok(compare_boolean(left, right, opts)), + (Utf8, Utf8) => Ok(compare_bytes::(left, right, opts)), + (LargeUtf8, LargeUtf8) => Ok(compare_bytes::(left, right, opts)), + (Binary, Binary) => Ok(compare_bytes::(left, right, opts)), + (LargeBinary, LargeBinary) => Ok(compare_bytes::(left, right, opts)), (FixedSizeBinary(_), FixedSizeBinary(_)) => { - let left = left.as_fixed_size_binary().clone(); - let right = right.as_fixed_size_binary().clone(); - Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j)))) + let left = left.as_fixed_size_binary(); + let right = right.as_fixed_size_binary(); + + let l = left.clone(); + let r = right.clone(); + Ok(compare(left, right, opts, move |i, j| { + l.value(i).cmp(r.value(j)) + })) }, + (List(_), List(_)) => compare_list::(left, right, opts), + (LargeList(_), LargeList(_)) => compare_list::(left, right, opts), + (FixedSizeList(_, _), FixedSizeList(_, _)) => compare_fixed_list(left, right, opts), + (Struct(_), Struct(_)) => compare_struct(left, right, opts), (Dictionary(l_key, _), Dictionary(r_key, _)) => { macro_rules! dict_helper { - ($t:ty, $left:expr, $right:expr) => { - compare_dict::<$t>($left, $right) + ($t:ty, $left:expr, $right:expr, $opts: expr) => { + compare_dict::<$t>($left, $right, $opts) }; } downcast_integer! { - l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right), + l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right, opts), _ => unreachable!() } }, @@ -140,7 +374,7 @@ pub mod tests { let items = vec![vec![1u8], vec![2u8]]; let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); } @@ -152,7 +386,7 @@ pub mod tests { let items = vec![vec![2u8]]; let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap(); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); } @@ -161,7 +395,7 @@ pub mod tests { fn test_i32() { let array = Int32Array::from(vec![1, 2]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, (cmp)(0, 1)); } @@ -171,7 +405,7 @@ pub mod tests { let array1 = Int32Array::from(vec![1]); let array2 = Int32Array::from(vec![2]); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); } @@ -180,7 +414,7 @@ pub mod tests { fn test_f16() { let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); } @@ -189,7 +423,7 @@ pub mod tests { fn test_f64() { let array = Float64Array::from(vec![1.0, 2.0]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); } @@ -198,7 +432,7 @@ pub mod tests { fn test_f64_nan() { let array = Float64Array::from(vec![1.0, f64::NAN]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Equal, cmp(1, 1)); @@ -208,7 +442,7 @@ pub mod tests { fn test_f64_zeros() { let array = Float64Array::from(vec![-0.0, 0.0]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(1, 0)); @@ -225,7 +459,7 @@ pub mod tests { IntervalDayTimeType::make_value(0, 90_000_000), ]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(1, 0)); @@ -248,7 +482,7 @@ pub mod tests { IntervalYearMonthType::make_value(1, 1), ]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(1, 0)); @@ -269,7 +503,7 @@ pub mod tests { IntervalMonthDayNanoType::make_value(0, 100, 2), ]); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(1, 0)); @@ -289,7 +523,7 @@ pub mod tests { .with_precision_and_scale(23, 6) .unwrap(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(1, 0)); assert_eq!(Ordering::Greater, cmp(0, 2)); } @@ -306,7 +540,7 @@ pub mod tests { .with_precision_and_scale(53, 6) .unwrap(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(1, 0)); assert_eq!(Ordering::Greater, cmp(0, 2)); } @@ -316,7 +550,7 @@ pub mod tests { let data = vec!["a", "b", "c", "a", "a", "c", "c"]; let array = data.into_iter().collect::>(); - let cmp = build_compare(&array, &array).unwrap(); + let cmp = make_comparator(&array, &array, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Equal, cmp(3, 4)); @@ -330,7 +564,7 @@ pub mod tests { let d2 = vec!["e", "f", "g", "a"]; let a2 = d2.into_iter().collect::>(); - let cmp = build_compare(&a1, &a2).unwrap(); + let cmp = make_comparator(&a1, &a2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Equal, cmp(0, 3)); @@ -347,7 +581,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -366,7 +600,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -385,7 +619,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -408,7 +642,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); // v1 vs v3 assert_eq!(Ordering::Equal, cmp(0, 3)); // v1 vs v1 @@ -427,7 +661,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -446,7 +680,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -475,7 +709,7 @@ pub mod tests { let keys = Int8Array::from_iter_values([0, 1, 1, 3]); let array2 = DictionaryArray::new(keys, Arc::new(values)); - let cmp = build_compare(&array1, &array2).unwrap(); + let cmp = make_comparator(&array1, &array2, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 0)); assert_eq!(Ordering::Less, cmp(0, 3)); @@ -487,7 +721,7 @@ pub mod tests { fn test_bytes_impl() { let offsets = OffsetBuffer::from_lengths([3, 3, 1]); let a = GenericByteArray::::new(offsets, b"abcdefa".into(), None); - let cmp = build_compare(&a, &a).unwrap(); + let cmp = make_comparator(&a, &a, SortOptions::default()).unwrap(); assert_eq!(Ordering::Less, cmp(0, 1)); assert_eq!(Ordering::Greater, cmp(0, 2)); diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index fe3a1f86ac00..8ae87787d283 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -17,13 +17,13 @@ //! Defines sort kernel for `ArrayRef` -use crate::ord::{build_compare, DynComparator}; +use crate::ord::{make_comparator, DynComparator}; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; +use arrow_buffer::ArrowNativeType; use arrow_buffer::BooleanBufferBuilder; -use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType}; use arrow_select::take::take; @@ -704,60 +704,21 @@ where } } -type LexicographicalCompareItem = ( - Option, // nulls - DynComparator, // comparator - SortOptions, // sort_option -); - /// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data /// at given two indices. The lifetime is the same at the data wrapped. pub struct LexicographicalComparator { - compare_items: Vec, + compare_items: Vec, } impl LexicographicalComparator { /// lexicographically compare values at the wrapped columns with given indices. pub fn compare(&self, a_idx: usize, b_idx: usize) -> Ordering { - for (nulls, comparator, sort_option) in &self.compare_items { - let (lhs_valid, rhs_valid) = match nulls { - Some(n) => (n.is_valid(a_idx), n.is_valid(b_idx)), - None => (true, true), - }; - - match (lhs_valid, rhs_valid) { - (true, true) => { - match (comparator)(a_idx, b_idx) { - // equal, move on to next column - Ordering::Equal => continue, - order => { - if sort_option.descending { - return order.reverse(); - } else { - return order; - } - } - } - } - (false, true) => { - return if sort_option.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (true, false) => { - return if sort_option.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - // equal, move on to next column - (false, false) => continue, + for comparator in &self.compare_items { + match comparator(a_idx, b_idx) { + Ordering::Equal => continue, + r => return r, } } - Ordering::Equal } @@ -766,61 +727,16 @@ impl LexicographicalComparator { pub fn try_new(columns: &[SortColumn]) -> Result { let compare_items = columns .iter() - .map(Self::build_compare_item) + .map(|c| { + make_comparator( + c.values.as_ref(), + c.values.as_ref(), + c.options.unwrap_or_default(), + ) + }) .collect::, ArrowError>>()?; Ok(LexicographicalComparator { compare_items }) } - - fn build_compare_item(column: &SortColumn) -> Result { - let values = column.values.as_ref(); - let options = column.options.unwrap_or_default(); - let comparator = match values.data_type() { - DataType::List(_) => Self::build_list_compare(values.as_list::(), options)?, - DataType::LargeList(_) => Self::build_list_compare(values.as_list::(), options)?, - DataType::FixedSizeList(_, _) => { - Self::build_fixed_size_list_compare(values.as_fixed_size_list(), options)? - } - _ => build_compare(values, values)?, - }; - Ok((values.logical_nulls(), comparator, options)) - } - - fn build_list_compare( - array: &GenericListArray, - options: SortOptions, - ) -> Result { - let rank = child_rank(array.values().as_ref(), options)?; - let offsets = array.offsets().clone(); - let cmp = Box::new(move |i: usize, j: usize| { - macro_rules! nth_value { - ($INDEX:expr) => {{ - let end = offsets[$INDEX + 1].as_usize(); - let start = offsets[$INDEX].as_usize(); - &rank[start..end] - }}; - } - Ord::cmp(nth_value!(i), nth_value!(j)) - }); - Ok(cmp) - } - - fn build_fixed_size_list_compare( - array: &FixedSizeListArray, - options: SortOptions, - ) -> Result { - let rank = child_rank(array.values().as_ref(), options)?; - let size = array.value_length() as usize; - let cmp = Box::new(move |i: usize, j: usize| { - macro_rules! nth_value { - ($INDEX:expr) => {{ - let start = $INDEX * size; - &rank[start..start + size] - }}; - } - Ord::cmp(nth_value!(i), nth_value!(j)) - }); - Ok(cmp) - } } #[cfg(test)] @@ -829,7 +745,7 @@ mod tests { use arrow_array::builder::{ FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder, }; - use arrow_buffer::i256; + use arrow_buffer::{i256, NullBuffer}; use half::f16; use rand::rngs::StdRng; use rand::{Rng, RngCore, SeedableRng};