From 714f124618c500e38d3198a40cb51514529a0184 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Jun 2021 08:51:05 -0400 Subject: [PATCH] refactor lexico sort (#423) (#442) Co-authored-by: Jiayu Liu --- arrow/src/compute/kernels/sort.rs | 113 ++++++++++++++++++------------ 1 file changed, 69 insertions(+), 44 deletions(-) diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 76b22fa4aed9..2b7ad8dc68f8 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -17,14 +17,12 @@ //! Defines sort kernel for `ArrayRef` -use std::cmp::Ordering; - use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::take; use crate::datatypes::*; use crate::error::{ArrowError, Result}; - +use std::cmp::Ordering; use TimeUnit::*; /// Sort the `ArrayRef` using `SortOptions`. @@ -835,26 +833,55 @@ pub fn lexsort_to_indices( )); }; - // map to data and DynComparator - let flat_columns = columns - .iter() - .map( - |column| -> Result<(&ArrayData, DynComparator, SortOptions)> { - // flatten and convert build comparators - // use ArrayData for is_valid checks later to avoid dynamic call - let values = column.values.as_ref(); - let data = values.data_ref(); - Ok(( - data, - build_compare(values, values)?, - column.options.unwrap_or_default(), - )) - }, - ) - .collect::<Result<Vec<(&ArrayData, DynComparator, SortOptions)>>>()?; + let mut value_indices = (0..row_count).collect::<Vec<usize>>(); + let mut len = value_indices.len(); + + if let Some(limit) = limit { + len = limit.min(len); + } + + let lexicographical_comparator = LexicographicalComparator::try_new(columns)?; + sort_by(&mut value_indices, len, |a, b| { + lexicographical_comparator.compare(a, b) + }); + + Ok(UInt32Array::from( + (&value_indices)[0..len] + .iter() + .map(|i| *i as u32) + .collect::<Vec<u32>>(), + )) +} + +/// It's unstable_sort, may not preserve the order of equal elements +pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F) +where + F: FnMut(&T, &T) -> Ordering, +{ + let (before, _mid, 
_after) = v.select_nth_unstable_by(limit, &mut is_less); + before.sort_unstable_by(is_less); +} + +type LexicographicalCompareItem<'a> = ( + &'a ArrayData, // data + Box<dyn Fn(usize, usize) -> Ordering + 'a>, // comparator + SortOptions, // sort_option +); + +/// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data +/// at two given indices. The lifetime is the same as the data wrapped. +pub(super) struct LexicographicalComparator<'a> { + compare_items: Vec<LexicographicalCompareItem<'a>>, +} - let lex_comparator = |a_idx: &usize, b_idx: &usize| -> Ordering { - for (data, comparator, sort_option) in flat_columns.iter() { +impl LexicographicalComparator<'_> { + /// lexicographically compare values at the wrapped columns with given indices. + pub(super) fn compare<'a, 'b>( + &'a self, + a_idx: &'b usize, + b_idx: &'b usize, + ) -> Ordering { + for (data, comparator, sort_option) in &self.compare_items { match (data.is_valid(*a_idx), data.is_valid(*b_idx)) { (true, true) => { match (comparator)(*a_idx, *b_idx) { @@ -889,31 +916,29 @@ pub fn lexsort_to_indices( } Ordering::Equal - }; - - let mut value_indices = (0..row_count).collect::<Vec<usize>>(); - let mut len = value_indices.len(); - - if let Some(limit) = limit { - len = limit.min(len); } - sort_by(&mut value_indices, len, lex_comparator); - Ok(UInt32Array::from( - (&value_indices)[0..len] + /// Create a new lex comparator that will wrap the given sort columns and give comparison + /// results with two indices. 
+ pub(super) fn try_new( + columns: &[SortColumn], + ) -> Result<LexicographicalComparator<'_>> { + let compare_items = columns .iter() - .map(|i| *i as u32) - .collect::<Vec<u32>>(), - )) -} - -/// It's unstable_sort, may not preserve the order of equal elements -pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F) -where - F: FnMut(&T, &T) -> Ordering, -{ - let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less); - before.sort_unstable_by(is_less); + .map(|column| { + // flatten and convert build comparators + // use ArrayData for is_valid checks later to avoid dynamic call + let values = column.values.as_ref(); + let data = values.data_ref(); + Ok(( + data, + build_compare(values, values)?, + column.options.unwrap_or_default(), + )) + }) + .collect::<Result<Vec<_>>>()?; + Ok(LexicographicalComparator { compare_items }) + } } #[cfg(test)]