Skip to content

Commit

Permalink
Fix primitive sort when input contains more nulls than the given sort…
Browse files Browse the repository at this point in the history
… limit (#954)
  • Loading branch information
jhorstmann authored Nov 18, 2021
1 parent 007fb58 commit 02f3ec8
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1032,13 +1032,11 @@ fn sort_valids<T, U>(
) where
T: ?Sized + Copy,
{
let nulls_len = nulls.len();
let valids_len = valids.len();
if !descending {
sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| cmp(a.1, b.1));
sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1));
} else {
sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
cmp(a.1, b.1).reverse()
});
sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse());
// reverse to keep a stable ordering
nulls.reverse();
}
Expand All @@ -1050,13 +1048,13 @@ fn sort_valids_array<T>(
nulls: &mut [T],
len: usize,
) {
let nulls_len = nulls.len();
let valids_len = valids.len();
if !descending {
sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
sort_unstable_by(valids, len.min(valids_len), |a, b| {
cmp_array(a.1.as_ref(), b.1.as_ref())
});
} else {
sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| {
sort_unstable_by(valids, len.min(valids_len), |a, b| {
cmp_array(a.1.as_ref(), b.1.as_ref()).reverse()
});
// reverse to keep a stable ordering
Expand Down Expand Up @@ -1555,6 +1553,19 @@ mod tests {
);
}

#[test]
fn test_sort_to_indices_primitive_more_nulls_than_limit() {
test_sort_to_indices_primitive_arrays::<Int32Type>(
vec![None, None, Some(3), None, Some(1), None, Some(2)],
Some(SortOptions {
descending: false,
nulls_first: false,
}),
Some(2),
vec![4, 6],
);
}

#[test]
fn test_sort_boolean() {
// boolean
Expand Down

0 comments on commit 02f3ec8

Please sign in to comment.