diff --git a/library/core/src/slice/sort.rs b/library/core/src/slice/sort.rs index 87f77b7f21d6..552f90452390 100644 --- a/library/core/src/slice/sort.rs +++ b/library/core/src/slice/sort.rs @@ -6,9 +6,10 @@ //! Unstable sorting is compatible with libcore because it doesn't allocate memory, unlike our //! stable sorting implementation. -use crate::cmp; -use crate::mem::{self, MaybeUninit, SizedTypeProperties}; -use crate::ptr; +use core::cmp; +use core::intrinsics; +use core::mem::{self, MaybeUninit, SizedTypeProperties}; +use core::ptr; /// When dropped, copies from `src` into `dest`. struct CopyOnDrop { @@ -27,98 +28,6 @@ impl Drop for CopyOnDrop { } } -/// Shifts the first element to the right until it encounters a greater or equal element. -fn shift_head(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - let len = v.len(); - // SAFETY: The unsafe operations below involves indexing without a bounds check (by offsetting a - // pointer) and copying memory (`ptr::copy_nonoverlapping`). - // - // a. Indexing: - // 1. We checked the size of the array to >=2. - // 2. All the indexing that we will do is always between {0 <= index < len} at most. - // - // b. Memory copying - // 1. We are obtaining pointers to references which are guaranteed to be valid. - // 2. They cannot overlap because we obtain pointers to difference indices of the slice. - // Namely, `i` and `i-1`. - // 3. If the slice is properly aligned, the elements are properly aligned. - // It is the caller's responsibility to make sure the slice is properly aligned. - // - // See comments below for further detail. - unsafe { - // If the first two elements are out-of-order... - if len >= 2 && is_less(v.get_unchecked(1), v.get_unchecked(0)) { - // Read the first element into a stack-allocated variable. If a following comparison - // operation panics, `hole` will get dropped and automatically write the element back - // into the slice. - let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(0))); - let v = v.as_mut_ptr(); - let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(1) }; - ptr::copy_nonoverlapping(v.add(1), v.add(0), 1); - - for i in 2..len { - if !is_less(&*v.add(i), &*tmp) { - break; - } - - // Move `i`-th element one place to the left, thus shifting the hole to the right. - ptr::copy_nonoverlapping(v.add(i), v.add(i - 1), 1); - hole.dest = v.add(i); - } - // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. - } - } -} - -/// Shifts the last element to the left until it encounters a smaller or equal element. -fn shift_tail(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - let len = v.len(); - // SAFETY: The unsafe operations below involves indexing without a bound check (by offsetting a - // pointer) and copying memory (`ptr::copy_nonoverlapping`). - // - // a. Indexing: - // 1. We checked the size of the array to >= 2. - // 2. All the indexing that we will do is always between `0 <= index < len-1` at most. - // - // b. Memory copying - // 1. We are obtaining pointers to references which are guaranteed to be valid. - // 2. They cannot overlap because we obtain pointers to difference indices of the slice. - // Namely, `i` and `i+1`. - // 3. If the slice is properly aligned, the elements are properly aligned. - // It is the caller's responsibility to make sure the slice is properly aligned. - // - // See comments below for further detail. - unsafe { - // If the last two elements are out-of-order... 
- if len >= 2 && is_less(v.get_unchecked(len - 1), v.get_unchecked(len - 2)) { - // Read the last element into a stack-allocated variable. If a following comparison - // operation panics, `hole` will get dropped and automatically write the element back - // into the slice. - let tmp = mem::ManuallyDrop::new(ptr::read(v.get_unchecked(len - 1))); - let v = v.as_mut_ptr(); - let mut hole = CopyOnDrop { src: &*tmp, dest: v.add(len - 2) }; - ptr::copy_nonoverlapping(v.add(len - 2), v.add(len - 1), 1); - - for i in (0..len - 2).rev() { - if !is_less(&*tmp, &*v.add(i)) { - break; - } - - // Move `i`-th element one place to the right, thus shifting the hole to the left. - ptr::copy_nonoverlapping(v.add(i), v.add(i + 1), 1); - hole.dest = v.add(i); - } - // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. - } - } -} - /// Partially sorts a slice by shifting several out-of-order elements around. /// /// Returns `true` if the slice is sorted at the end. This function is *O*(*n*) worst-case. @@ -158,26 +67,35 @@ where // Swap the found pair of elements. This puts them in correct order. v.swap(i - 1, i); + if i >= 2 { + // SAFETY: We check the that the slice len is >= 2. + unsafe { + insert_tail(&mut v[..i], is_less); + } + } + // Shift the smaller element to the left. - shift_tail(&mut v[..i], is_less); + if i >= 2 { + // SAFETY: We check the that the slice len is >= 2. + unsafe { + insert_tail(&mut v[..i], is_less); + } + } + // Shift the greater element to the right. - shift_head(&mut v[i..], is_less); + if i < (len - 1) { + // SAFETY: We check the that the slice len is >= 2. + unsafe { + // shift_head(&mut v[i..], is_less); + insert_head(&mut v[i..], is_less); + } + } } // Didn't manage to sort the slice in the limited number of steps. false } -/// Sorts a slice using insertion sort, which is *O*(*n*^2) worst-case. -fn insertion_sort(v: &mut [T], is_less: &mut F) -where - F: FnMut(&T, &T) -> bool, -{ - for i in 1..v.len() { - shift_tail(&mut v[..i + 1], is_less); - } -} - /// Sorts `v` using heapsort, which guarantees *O*(*n* \* log(*n*)) worst-case. #[cold] #[unstable(feature = "sort_internals", reason = "internal to sort module", issue = "none")] @@ -326,8 +244,8 @@ where unsafe { // Branchless comparison. *end_l = i as u8; - end_l = end_l.add(!is_less(&*elem, pivot) as usize); - elem = elem.add(1); + end_l = end_l.offset(!is_less(&*elem, pivot) as isize); + elem = elem.offset(1); } } } @@ -352,9 +270,9 @@ where // Plus, `block_r` was asserted to be less than `BLOCK` and `elem` will therefore at most be pointing to the beginning of the slice. unsafe { // Branchless comparison. - elem = elem.sub(1); + elem = elem.offset(-1); *end_r = i as u8; - end_r = end_r.add(is_less(&*elem, pivot) as usize); + end_r = end_r.offset(is_less(&*elem, pivot) as isize); } } } @@ -365,12 +283,12 @@ where if count > 0 { macro_rules! left { () => { - l.add(usize::from(*start_l)) + l.offset(*start_l as isize) }; } macro_rules! 
right { () => { - r.sub(usize::from(*start_r) + 1) + r.offset(-(*start_r as isize) - 1) }; } @@ -398,16 +316,16 @@ where ptr::copy_nonoverlapping(right!(), left!(), 1); for _ in 1..count { - start_l = start_l.add(1); + start_l = start_l.offset(1); ptr::copy_nonoverlapping(left!(), right!(), 1); - start_r = start_r.add(1); + start_r = start_r.offset(1); ptr::copy_nonoverlapping(right!(), left!(), 1); } ptr::copy_nonoverlapping(&tmp, right!(), 1); mem::forget(tmp); - start_l = start_l.add(1); - start_r = start_r.add(1); + start_l = start_l.offset(1); + start_r = start_r.offset(1); } } @@ -420,7 +338,7 @@ where // safe. Otherwise, the debug assertions in the `is_done` case guarantee that // `width(l, r) == block_l + block_r`, namely, that the block sizes have been adjusted to account // for the smaller number of remaining elements. - l = unsafe { l.add(block_l) }; + l = unsafe { l.offset(block_l as isize) }; } if start_r == end_r { @@ -428,7 +346,7 @@ where // SAFETY: Same argument as [block-width-guarantee]. Either this is a full block `2*BLOCK`-wide, // or `block_r` has been adjusted for the last handful of elements. - r = unsafe { r.sub(block_r) }; + r = unsafe { r.offset(-(block_r as isize)) }; } if is_done { @@ -457,9 +375,9 @@ where // - `offsets_l` contains valid offsets into `v` collected during the partitioning of // the last block, so the `l.offset` calls are valid. unsafe { - end_l = end_l.sub(1); - ptr::swap(l.add(usize::from(*end_l)), r.sub(1)); - r = r.sub(1); + end_l = end_l.offset(-1); + ptr::swap(l.offset(*end_l as isize), r.offset(-1)); + r = r.offset(-1); } } width(v.as_mut_ptr(), r) @@ -470,9 +388,9 @@ where while start_r < end_r { // SAFETY: See the reasoning in [remaining-elements-safety]. unsafe { - end_r = end_r.sub(1); - ptr::swap(l, r.sub(usize::from(*end_r) + 1)); - l = l.add(1); + end_r = end_r.offset(-1); + ptr::swap(l, r.offset(-(*end_r as isize) - 1)); + l = l.offset(1); } } width(v.as_mut_ptr(), l) @@ -659,6 +577,12 @@ where let len = v.len(); + if len <= MAX_INSERTION { + // It's a logic bug if this get's called on slice that would be small-sorted. + debug_assert!(false); + return (10, false); + } + // Three indices near which we are going to choose a pivot. let mut a = len / 4 * 1; let mut b = len / 4 * 2; @@ -667,45 +591,46 @@ where // Counts the total number of swaps we are about to perform while sorting indices. let mut swaps = 0; - if len >= 8 { - // Swaps indices so that `v[a] <= v[b]`. - // SAFETY: `len >= 8` so there are at least two elements in the neighborhoods of - // `a`, `b` and `c`. This means the three calls to `sort_adjacent` result in - // corresponding calls to `sort3` with valid 3-item neighborhoods around each - // pointer, which in turn means the calls to `sort2` are done with valid - // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap` - // call. - let mut sort2 = |a: &mut usize, b: &mut usize| unsafe { - if is_less(v.get_unchecked(*b), v.get_unchecked(*a)) { - ptr::swap(a, b); - swaps += 1; - } - }; - - // Swaps indices so that `v[a] <= v[b] <= v[c]`. - let mut sort3 = |a: &mut usize, b: &mut usize, c: &mut usize| { - sort2(a, b); - sort2(b, c); - sort2(a, b); - }; + // Swaps indices so that `v[a] <= v[b]`. + // SAFETY: `len > 20` so there are at least two elements in the neighborhoods of + // `a`, `b` and `c`. 
This means the three calls to `sort_adjacent` result in + // corresponding calls to `sort3` with valid 3-item neighborhoods around each + // pointer, which in turn means the calls to `sort2` are done with valid + // references. Thus the `v.get_unchecked` calls are safe, as is the `ptr::swap` + // call. + let mut sort2_idx = |a: &mut usize, b: &mut usize| unsafe { + let should_swap = is_less(v.get_unchecked(*b), v.get_unchecked(*a)); + + // Generate branchless cmov code, it's not super important but reduces BHB and BTB pressure. + let tmp_idx = if should_swap { *a } else { *b }; + *a = if should_swap { *b } else { *a }; + *b = tmp_idx; + swaps += should_swap as usize; + }; - if len >= SHORTEST_MEDIAN_OF_MEDIANS { - // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`. - let mut sort_adjacent = |a: &mut usize| { - let tmp = *a; - sort3(&mut (tmp - 1), a, &mut (tmp + 1)); - }; + // Swaps indices so that `v[a] <= v[b] <= v[c]`. + let mut sort3_idx = |a: &mut usize, b: &mut usize, c: &mut usize| { + sort2_idx(a, b); + sort2_idx(b, c); + sort2_idx(a, b); + }; - // Find medians in the neighborhoods of `a`, `b`, and `c`. - sort_adjacent(&mut a); - sort_adjacent(&mut b); - sort_adjacent(&mut c); - } + if len >= SHORTEST_MEDIAN_OF_MEDIANS { + // Finds the median of `v[a - 1], v[a], v[a + 1]` and stores the index into `a`. + let mut sort_adjacent = |a: &mut usize| { + let tmp = *a; + sort3_idx(&mut (tmp - 1), a, &mut (tmp + 1)); + }; - // Find the median among `a`, `b`, and `c`. - sort3(&mut a, &mut b, &mut c); + // Find medians in the neighborhoods of `a`, `b`, and `c`. + sort_adjacent(&mut a); + sort_adjacent(&mut b); + sort_adjacent(&mut c); } + // Find the median among `a`, `b`, and `c`. + sort3_idx(&mut a, &mut b, &mut c); + if swaps < MAX_SWAPS { (b, swaps == 0) } else { @@ -716,6 +641,9 @@ where } } +// Slices of up to this length get sorted using insertion sort. +const MAX_INSERTION: usize = 20; + /// Sorts `v` recursively. /// /// If the slice had a predecessor in the original array, it is specified as `pred`. @@ -726,9 +654,6 @@ fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &mut F, mut pred: Option<&'a T where F: FnMut(&T, &T) -> bool, { - // Slices of up to this length get sorted using insertion sort. - const MAX_INSERTION: usize = 20; - // True if the last partitioning was reasonably balanced. let mut was_balanced = true; // True if the last partitioning didn't shuffle elements (the slice was already partitioned). @@ -737,9 +662,9 @@ where loop { let len = v.len(); - // Very short slices get sorted using insertion sort. - if len <= MAX_INSERTION { - insertion_sort(v, is_less); + // println!("len: {len}"); + + if sort_small(v, is_less) { return; } @@ -807,13 +732,140 @@ where } } +/// Sorts `v` using strategies optimized for small sizes. +pub fn sort_small(v: &mut [T], is_less: &mut F) -> bool +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + const MAX_BRANCHLESS_SMALL_SORT: usize = 40; + + if len < 2 { + return true; + } + + if qualifies_for_branchless_sort::() && len <= MAX_BRANCHLESS_SMALL_SORT { + if len < 8 { + // For small sizes it's better to just sort. The worst case 7, will only go from 6 to 8 + // comparisons for already sorted inputs. + let start = if len >= 4 { + // SAFETY: We just checked the len. + unsafe { + sort4_optimal(&mut v[0..4], is_less); + } + 4 + } else { + 1 + }; + + insertion_sort_shift_left(v, start, is_less); + + return true; + } + + // Pattern analyze to minimize comparison count for already sorted or reversed inputs. 
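+ // The backwards scan below finds the longest run that already ends the slice: either
+ // non-descending, or strictly descending (which is then reversed in place), so that,
+ // when that run is long enough, only the remaining prefix `v[..start]` still has to be
+ // sorted and merged with it.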
+ // For larger inputs pdqsort pattern analysis will be used. + + let mut start = len - 1; + if start > 0 { + start -= 1; + unsafe { + if is_less(v.get_unchecked(start + 1), v.get_unchecked(start)) { + while start > 0 && is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) { + start -= 1; + } + v[start..len].reverse(); + } else { + while start > 0 && !is_less(v.get_unchecked(start), v.get_unchecked(start - 1)) + { + start -= 1; + } + } + } + } + + debug_assert!(start < len); + + let already_sorted = len - start; + + if already_sorted <= 6 { + // SAFETY: We check the len. + unsafe { + match len { + 8..=15 => { + sort8_plus(v, is_less); + } + 16..=31 => { + sort16_plus(v, is_less); + } + 32..=40 => { + sort32_plus(v, is_less); + } + _ => { + unreachable!() + } + } + } + } else { + // Potentially highly or fully sorted. We know that already_sorted >= 7. and len >= 8. + // That leaves the range of start <= 33. + debug_assert!(start <= 33); + + if start == 0 { + return true; + } else if start <= 3 { + insertion_sort_shift_right(v, start, is_less); + return true; + } + + match start { + 4..=7 => { + // SAFETY: We just checked start >= 4. + unsafe { + sort4_plus(&mut v[0..start], is_less); + } + } + 8..=15 => { + // SAFETY: We just checked start >= 8. + unsafe { + sort8_plus(&mut v[0..start], is_less); + } + } + 16..=33 => { + // SAFETY: We just checked start >= 16. + unsafe { + sort16_plus(&mut v[0..start], is_less); + } + } + _ => unreachable!(), + } + + // The longest possible shortest side is len == 40, start == 20 -> 20. + let mut swap = mem::MaybeUninit::<[T; 20]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + // SAFETY: swap is long enough and both sides are len >= 1. + unsafe { + merge(v, start, swap_ptr, is_less); + } + } + return true; + } else if len <= MAX_INSERTION { + insertion_sort_shift_left(v, 1, is_less); + return true; + } + + false +} + /// Sorts `v` using pattern-defeating quicksort, which is *O*(*n* \* log(*n*)) worst-case. pub fn quicksort(v: &mut [T], mut is_less: F) where F: FnMut(&T, &T) -> bool, { // Sorting has no meaningful behavior on zero-sized types. - if T::IS_ZST { + if mem::size_of::() == 0 { return; } @@ -823,6 +875,609 @@ where recurse(v, &mut is_less, None, limit); } +// --- Insertion sorts --- + +// TODO unified sort module. + +// When dropped, copies from `src` into `dest`. +struct InsertionHole { + src: *const T, + dest: *mut T, +} + +impl Drop for InsertionHole { + fn drop(&mut self) { + unsafe { + ptr::copy_nonoverlapping(self.src, self.dest, 1); + } + } +} + +/// Inserts `v[v.len() - 1]` into pre-sorted sequence `v[..v.len() - 1]` so that whole `v[..]` +/// becomes sorted. +unsafe fn insert_tail(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 2); + + let arr_ptr = v.as_mut_ptr(); + let i = v.len() - 1; + + // SAFETY: caller must ensure v is at least len 2. + unsafe { + // See insert_head which talks about why this approach is beneficial. + let i_ptr = arr_ptr.add(i); + + // It's important that we use i_ptr here. If this check is positive and we continue, + // We want to make sure that no other copy of the value was seen by is_less. + // Otherwise we would have to copy it back. + if is_less(&*i_ptr, &*i_ptr.sub(1)) { + // It's important, that we use tmp for comparison from now on. As it is the value that + // will be copied back. And notionally we could have created a divergence if we copy + // back the wrong value. 
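+ // (In other words: from this point on there are two bitwise copies of the element, the
+ // one still in the slice and the one in `tmp`; `tmp` is the copy that will eventually be
+ // written into the hole, so it must also be the copy that `is_less` observes.)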
+ let tmp = mem::ManuallyDrop::new(ptr::read(i_ptr)); + // Intermediate state of the insertion process is always tracked by `hole`, which + // serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` in the end. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut hole = InsertionHole { src: &*tmp, dest: i_ptr.sub(1) }; + ptr::copy_nonoverlapping(hole.dest, i_ptr, 1); + + // SAFETY: We know i is at least 1. + for j in (0..(i - 1)).rev() { + let j_ptr = arr_ptr.add(j); + if !is_less(&*tmp, &*j_ptr) { + break; + } + + ptr::copy_nonoverlapping(j_ptr, hole.dest, 1); + hole.dest = j_ptr; + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } + } +} + +/// Inserts `v[0]` into pre-sorted sequence `v[1..]` so that whole `v[..]` becomes sorted. +/// +/// This is the integral subroutine of insertion sort. +unsafe fn insert_head(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + debug_assert!(v.len() >= 2); + + unsafe { + if is_less(v.get_unchecked(1), v.get_unchecked(0)) { + let arr_ptr = v.as_mut_ptr(); + + // There are three ways to implement insertion here: + // + // 1. Swap adjacent elements until the first one gets to its final destination. + // However, this way we copy data around more than is necessary. If elements are big + // structures (costly to copy), this method will be slow. + // + // 2. Iterate until the right place for the first element is found. Then shift the + // elements succeeding it to make room for it and finally place it into the + // remaining hole. This is a good method. + // + // 3. Copy the first element into a temporary variable. Iterate until the right place + // for it is found. As we go along, copy every traversed element into the slot + // preceding it. Finally, copy data from the temporary variable into the remaining + // hole. This method is very good. Benchmarks demonstrated slightly better + // performance than with the 2nd method. + // + // All methods were benchmarked, and the 3rd showed best results. So we chose that one. + let tmp = mem::ManuallyDrop::new(ptr::read(arr_ptr)); + + // Intermediate state of the insertion process is always tracked by `hole`, which + // serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` in the end. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and + // fill the hole in `v` with `tmp`, thus ensuring that `v` still holds every object it + // initially held exactly once. + let mut hole = InsertionHole { src: &*tmp, dest: arr_ptr.add(1) }; + ptr::copy_nonoverlapping(arr_ptr.add(1), arr_ptr.add(0), 1); + + for i in 2..v.len() { + if !is_less(&v.get_unchecked(i), &*tmp) { + break; + } + ptr::copy_nonoverlapping(arr_ptr.add(i), arr_ptr.add(i - 1), 1); + hole.dest = arr_ptr.add(i); + } + // `hole` gets dropped and thus copies `tmp` into the remaining hole in `v`. + } + } +} + +/// Sort `v` assuming `v[..offset]` is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. 
+#[inline(never)] +fn insertion_sort_shift_left(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // This is a logic but not a safety bug. + debug_assert!(offset != 0 && offset <= len); + + if intrinsics::unlikely(((len < 2) as u8 + (offset == 0) as u8) != 0) { + return; + } + + // Shift each element of the unsorted region v[i..] as far left as is needed to make v sorted. + for i in offset..len { + // SAFETY: we tested that len >= 2. + unsafe { + // Maybe use insert_head here and avoid additional code. + insert_tail(&mut v[..=i], is_less); + } + } +} + +/// Sort `v` assuming `v[offset..]` is already sorted. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. Even improving performance in some cases. +#[inline(never)] +fn insertion_sort_shift_right(v: &mut [T], offset: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + + // This is a logic but not a safety bug. + debug_assert!(offset != 0 && offset <= len); + + if intrinsics::unlikely(((len < 2) as u8 + (offset == 0) as u8) != 0) { + return; + } + + // Shift each element of the unsorted region v[..i] as far left as is needed to make v sorted. + for i in (0..offset).rev() { + // We ensured that the slice length is always at least 2 long. + // We know that start_found will be at least one less than end, + // and the range is exclusive. Which gives us i always <= (end - 2). + unsafe { + insert_head(&mut v[i..len], is_less); + } + } +} + +/// Merges non-decreasing runs `v[..mid]` and `v[mid..]` using `buf` as temporary storage, and +/// stores the result into `v[..]`. +/// +/// # Safety +/// +/// The two slices must be non-empty and `mid` must be in bounds. Buffer `buf` must be long enough +/// to hold a copy of the shorter slice. Also, `T` must not be a zero-sized type. +/// +/// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +/// performance impact. +#[inline(never)] +#[cfg(not(no_global_oom_handling))] +unsafe fn merge(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + let len = v.len(); + let arr_ptr = v.as_mut_ptr(); + let (v_mid, v_end) = unsafe { (arr_ptr.add(mid), arr_ptr.add(len)) }; + + // The merge process first copies the shorter run into `buf`. Then it traces the newly copied + // run and the longer run forwards (or backwards), comparing their next unconsumed elements and + // copying the lesser (or greater) one into `v`. + // + // As soon as the shorter run is fully consumed, the process is done. If the longer run gets + // consumed first, then we must copy whatever is left of the shorter run into the remaining + // hole in `v`. + // + // Intermediate state of the process is always tracked by `hole`, which serves two purposes: + // 1. Protects integrity of `v` from panics in `is_less`. + // 2. Fills the remaining hole in `v` if the longer run gets consumed first. + // + // Panic safety: + // + // If `is_less` panics at any point during the process, `hole` will get dropped and fill the + // hole in `v` with the unconsumed range in `buf`, thus ensuring that `v` still holds every + // object it initially held exactly once. + let mut hole; + + if mid <= len - mid { + // The left run is shorter. 
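+ // Copy it into `buf` and merge forwards: `hole.start..hole.end` is the part of the
+ // buffered left run that has not been merged back yet, and `hole.dest` is the write
+ // cursor into `v`.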
+ unsafe { + ptr::copy_nonoverlapping(arr_ptr, buf, mid); + hole = MergeHole { start: buf, end: buf.add(mid), dest: arr_ptr }; + } + + // Initially, these pointers point to the beginnings of their arrays. + let left = &mut hole.start; + let mut right = v_mid; + let out = &mut hole.dest; + + while *left < hole.end && right < v_end { + // Consume the lesser side. + // If equal, prefer the left run to maintain stability. + unsafe { + let to_copy = if is_less(&*right, &**left) { + get_and_increment(&mut right) + } else { + get_and_increment(left) + }; + ptr::copy_nonoverlapping(to_copy, get_and_increment(out), 1); + } + } + } else { + // The right run is shorter. + unsafe { + ptr::copy_nonoverlapping(v_mid, buf, len - mid); + hole = MergeHole { start: buf, end: buf.add(len - mid), dest: v_mid }; + } + + // Initially, these pointers point past the ends of their arrays. + let left = &mut hole.dest; + let right = &mut hole.end; + let mut out = v_end; + + while arr_ptr < *left && buf < *right { + // Consume the greater side. + // If equal, prefer the right run to maintain stability. + unsafe { + let to_copy = if is_less(&*right.offset(-1), &*left.offset(-1)) { + decrement_and_get(left) + } else { + decrement_and_get(right) + }; + ptr::copy_nonoverlapping(to_copy, decrement_and_get(&mut out), 1); + } + } + } + // Finally, `hole` gets dropped. If the shorter run was not fully consumed, whatever remains of + // it will now be copied into the hole in `v`. + + unsafe fn get_and_increment(ptr: &mut *mut T) -> *mut T { + let old = *ptr; + *ptr = unsafe { ptr.offset(1) }; + old + } + + unsafe fn decrement_and_get(ptr: &mut *mut T) -> *mut T { + *ptr = unsafe { ptr.offset(-1) }; + *ptr + } + + // When dropped, copies the range `start..end` into `dest..`. + struct MergeHole { + start: *mut T, + end: *mut T, + dest: *mut T, + } + + impl Drop for MergeHole { + fn drop(&mut self) { + // `T` is not a zero-sized type, and these are pointers into a slice's elements. + unsafe { + let len = self.end.sub_ptr(self.start); + ptr::copy_nonoverlapping(self.start, self.dest, len); + } + } + } +} + +// --- Branchless sorting (less branches not zero) --- + +#[inline] +fn qualifies_for_branchless_sort() -> bool { + // This is a heuristic, and as such it will guess wrong from time to time. The two parts broken + // down: + // + // - Type size: Large types are more expensive to move and the time won avoiding branches can be + // offset by the increased cost of moving the values. + // + // In contrast to stable sort, using sorting networks here, allows to do fewer comparisons. + mem::size_of::() <= mem::size_of::<[usize; 4]>() +} + +/// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. +#[inline] +unsafe fn branchless_swap(a_ptr: *mut T, b_ptr: *mut T, should_swap: bool) { + // This is a branchless version of swap if. + // The equivalent code with a branch would be: + // + // if should_swap { + // ptr::swap_nonoverlapping(a_ptr, b_ptr, 1); + // } + + // Give ourselves some scratch space to work with. + // We do not have to worry about drops: `MaybeUninit` does nothing when dropped. + let mut tmp = mem::MaybeUninit::::uninit(); + + // The goal is to generate cmov instructions here. + let a_swap_ptr = if should_swap { b_ptr } else { a_ptr }; + let b_swap_ptr = if should_swap { a_ptr } else { b_ptr }; + + // SAFETY: the caller must guarantee that `a_ptr` and `b_ptr` are valid for writes + // and properly aligned, and part of the same allocation, and do not alias. 
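+ // Note: when `should_swap` is false, `a_swap_ptr == a_ptr`, so the middle copy below may
+ // be a self-copy; that is why it uses `ptr::copy` rather than `ptr::copy_nonoverlapping`.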
+ unsafe { + ptr::copy_nonoverlapping(b_swap_ptr, tmp.as_mut_ptr(), 1); + ptr::copy(a_swap_ptr, a_ptr, 1); + ptr::copy_nonoverlapping(tmp.as_ptr(), b_ptr, 1); + } +} + +/// Swap two values in array pointed to by a_ptr and b_ptr if b is less than a. +#[inline] +unsafe fn swap_if_less(arr_ptr: *mut T, a: usize, b: usize, is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: the caller must guarantee that `a` and `b` each added to `arr_ptr` yield valid + // pointers into `arr_ptr`. and properly aligned, and part of the same allocation, and do not + // alias. `a` and `b` must be different numbers. + unsafe { + debug_assert!(a != b); + + let a_ptr = arr_ptr.add(a); + let b_ptr = arr_ptr.add(b); + + // PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be + // in a well defined state, without duplicates. + + // Important to only swap if it is more and not if it is equal. is_less should return false for + // equal, so we don't swap. + let should_swap = is_less(&*b_ptr, &*a_ptr); + + branchless_swap(a_ptr, b_ptr, should_swap); + } +} + +// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +// performance impact. +#[inline(never)] +unsafe fn sort4_optimal(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 4. + unsafe { + debug_assert!(v.len() == 4); + + let arr_ptr = v.as_mut_ptr(); + + // Optimal sorting network see: + // https://bertdobbelaere.github.io/sorting_networks.html. + + swap_if_less(arr_ptr, 0, 2, is_less); + swap_if_less(arr_ptr, 1, 3, is_less); + swap_if_less(arr_ptr, 0, 1, is_less); + swap_if_less(arr_ptr, 2, 3, is_less); + swap_if_less(arr_ptr, 1, 2, is_less); + } +} + +// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +// performance impact. +#[inline(never)] +unsafe fn sort8_optimal(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 8. + unsafe { + debug_assert!(v.len() == 8); + + let arr_ptr = v.as_mut_ptr(); + + // Optimal sorting network see: + // https://bertdobbelaere.github.io/sorting_networks.html. + + swap_if_less(arr_ptr, 0, 2, is_less); + swap_if_less(arr_ptr, 1, 3, is_less); + swap_if_less(arr_ptr, 4, 6, is_less); + swap_if_less(arr_ptr, 5, 7, is_less); + swap_if_less(arr_ptr, 0, 4, is_less); + swap_if_less(arr_ptr, 1, 5, is_less); + swap_if_less(arr_ptr, 2, 6, is_less); + swap_if_less(arr_ptr, 3, 7, is_less); + swap_if_less(arr_ptr, 0, 1, is_less); + swap_if_less(arr_ptr, 2, 3, is_less); + swap_if_less(arr_ptr, 4, 5, is_less); + swap_if_less(arr_ptr, 6, 7, is_less); + swap_if_less(arr_ptr, 2, 4, is_less); + swap_if_less(arr_ptr, 3, 5, is_less); + swap_if_less(arr_ptr, 1, 4, is_less); + swap_if_less(arr_ptr, 3, 6, is_less); + swap_if_less(arr_ptr, 1, 2, is_less); + swap_if_less(arr_ptr, 3, 4, is_less); + swap_if_less(arr_ptr, 5, 6, is_less); + } +} + +// Never inline this function to avoid code bloat. It still optimizes nicely and has practically no +// performance impact. +#[inline(never)] +unsafe fn sort16_optimal(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 16. 
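+ // (The body below is the 60-comparator, depth-10 optimal network for 16 elements; the
+ // `#N16L60D10` fragment of the link cited inside encodes exactly that.)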
+ unsafe { + debug_assert!(v.len() == 16); + + let arr_ptr = v.as_mut_ptr(); + + // Optimal sorting network see: + // https://bertdobbelaere.github.io/sorting_networks.html#N16L60D10 + + swap_if_less(arr_ptr, 0, 13, is_less); + swap_if_less(arr_ptr, 1, 12, is_less); + swap_if_less(arr_ptr, 2, 15, is_less); + swap_if_less(arr_ptr, 3, 14, is_less); + swap_if_less(arr_ptr, 4, 8, is_less); + swap_if_less(arr_ptr, 5, 6, is_less); + swap_if_less(arr_ptr, 7, 11, is_less); + swap_if_less(arr_ptr, 9, 10, is_less); + swap_if_less(arr_ptr, 0, 5, is_less); + swap_if_less(arr_ptr, 1, 7, is_less); + swap_if_less(arr_ptr, 2, 9, is_less); + swap_if_less(arr_ptr, 3, 4, is_less); + swap_if_less(arr_ptr, 6, 13, is_less); + swap_if_less(arr_ptr, 8, 14, is_less); + swap_if_less(arr_ptr, 10, 15, is_less); + swap_if_less(arr_ptr, 11, 12, is_less); + swap_if_less(arr_ptr, 0, 1, is_less); + swap_if_less(arr_ptr, 2, 3, is_less); + swap_if_less(arr_ptr, 4, 5, is_less); + swap_if_less(arr_ptr, 6, 8, is_less); + swap_if_less(arr_ptr, 7, 9, is_less); + swap_if_less(arr_ptr, 10, 11, is_less); + swap_if_less(arr_ptr, 12, 13, is_less); + swap_if_less(arr_ptr, 14, 15, is_less); + swap_if_less(arr_ptr, 0, 2, is_less); + swap_if_less(arr_ptr, 1, 3, is_less); + swap_if_less(arr_ptr, 4, 10, is_less); + swap_if_less(arr_ptr, 5, 11, is_less); + swap_if_less(arr_ptr, 6, 7, is_less); + swap_if_less(arr_ptr, 8, 9, is_less); + swap_if_less(arr_ptr, 12, 14, is_less); + swap_if_less(arr_ptr, 13, 15, is_less); + swap_if_less(arr_ptr, 1, 2, is_less); + swap_if_less(arr_ptr, 3, 12, is_less); + swap_if_less(arr_ptr, 4, 6, is_less); + swap_if_less(arr_ptr, 5, 7, is_less); + swap_if_less(arr_ptr, 8, 10, is_less); + swap_if_less(arr_ptr, 9, 11, is_less); + swap_if_less(arr_ptr, 13, 14, is_less); + swap_if_less(arr_ptr, 1, 4, is_less); + swap_if_less(arr_ptr, 2, 6, is_less); + swap_if_less(arr_ptr, 5, 8, is_less); + swap_if_less(arr_ptr, 7, 10, is_less); + swap_if_less(arr_ptr, 9, 13, is_less); + swap_if_less(arr_ptr, 11, 14, is_less); + swap_if_less(arr_ptr, 2, 4, is_less); + swap_if_less(arr_ptr, 3, 6, is_less); + swap_if_less(arr_ptr, 9, 12, is_less); + swap_if_less(arr_ptr, 11, 13, is_less); + swap_if_less(arr_ptr, 3, 5, is_less); + swap_if_less(arr_ptr, 6, 8, is_less); + swap_if_less(arr_ptr, 7, 9, is_less); + swap_if_less(arr_ptr, 10, 12, is_less); + swap_if_less(arr_ptr, 3, 4, is_less); + swap_if_less(arr_ptr, 5, 6, is_less); + swap_if_less(arr_ptr, 7, 8, is_less); + swap_if_less(arr_ptr, 9, 10, is_less); + swap_if_less(arr_ptr, 11, 12, is_less); + swap_if_less(arr_ptr, 6, 7, is_less); + swap_if_less(arr_ptr, 8, 9, is_less); + } +} + +unsafe fn sort4_plus(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 4. + unsafe { + let len = v.len(); + debug_assert!(len >= 4); + + sort4_optimal(&mut v[0..4], is_less); + insertion_sort_shift_left(v, 4, is_less); + } +} + +unsafe fn sort8_plus(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 8. + unsafe { + let len = v.len(); + debug_assert!(len >= 8); + + sort8_optimal(&mut v[0..8], is_less); + + if len >= 9 { + insertion_sort_shift_left(&mut v[8..], 1, is_less); + + // We only need place for 8 entries because we know the shorter side is at most 8 long. 
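+ // (The left run is always exactly 8 elements, so whichever run is shorter never exceeds
+ // 8 elements and this stack buffer is always large enough for `merge`.)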
+ let mut swap = mem::MaybeUninit::<[T; 8]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + merge(v, 8, swap_ptr, is_less); + } + } +} + +unsafe fn sort16_plus(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 16. + unsafe { + let len = v.len(); + debug_assert!(len >= 16); + + sort16_optimal(&mut v[0..16], is_less); + + if len >= 17 { + let start = if len >= 24 { + sort8_optimal(&mut v[16..24], is_less); + 8 + } else if len >= 20 { + sort4_optimal(&mut v[16..20], is_less); + 4 + } else { + 1 + }; + + insertion_sort_shift_left(&mut v[16..], start, is_less); + + // We only need place for 16 entries because we know the shorter side is at most 16 long. + let mut swap = mem::MaybeUninit::<[T; 16]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + merge(v, 16, swap_ptr, is_less); + } + } +} + +unsafe fn sort32_plus(v: &mut [T], is_less: &mut F) +where + F: FnMut(&T, &T) -> bool, +{ + // SAFETY: caller must ensure v.len() >= 32. + unsafe { + debug_assert!(v.len() >= 32 && v.len() <= 40); + + sort16_optimal(&mut v[0..16], is_less); + sort16_optimal(&mut v[16..32], is_less); + + insertion_sort_shift_left(&mut v[16..], 16, is_less); + + // We only need place for 16 entries because we know the shorter side is 16 long. + let mut swap = mem::MaybeUninit::<[T; 16]>::uninit(); + let swap_ptr = swap.as_mut_ptr() as *mut T; + + merge(v, 16, swap_ptr, is_less); + } +} + fn partition_at_index_loop<'a, T, F>( mut v: &'a mut [T], mut index: usize, @@ -833,9 +1488,13 @@ fn partition_at_index_loop<'a, T, F>( { loop { // For slices of up to this length it's probably faster to simply sort them. + + // TODO use sort_small here? const MAX_INSERTION: usize = 10; if v.len() <= MAX_INSERTION { - insertion_sort(v, is_less); + if v.len() >= 2 { + insertion_sort_shift_left(v, 1, is_less); + } return; }
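Editor's aside, not part of the patch: the core trick behind the new small sorts is composing the `branchless_swap` compare-exchange into fixed optimal sorting networks (`sort4_optimal`, `sort8_optimal`, `sort16_optimal`). Below is a minimal, self-contained sketch of that idea specialized to `u32`, so it can be compiled and tested on its own; the names `branchless_swap_u32` and `sort4_net` are invented for this sketch and do not appear in the patch.

use std::ptr;

/// Swap `*a_ptr` and `*b_ptr` if `should_swap` is true, without a data-dependent branch.
///
/// SAFETY: both pointers must be valid for reads and writes, properly aligned, and must
/// not alias each other.
unsafe fn branchless_swap_u32(a_ptr: *mut u32, b_ptr: *mut u32, should_swap: bool) {
    // Pick the copy sources with (ideally) conditional moves instead of a branch, then do
    // the copies unconditionally. When `should_swap` is false the middle copy is a
    // self-copy, which is why it uses `ptr::copy` and not `ptr::copy_nonoverlapping`.
    let a_swap_ptr = if should_swap { b_ptr } else { a_ptr };
    let b_swap_ptr = if should_swap { a_ptr } else { b_ptr };

    // SAFETY: guaranteed by the caller.
    unsafe {
        let tmp = *b_swap_ptr;
        ptr::copy(a_swap_ptr, a_ptr, 1);
        *b_ptr = tmp;
    }
}

/// Sort exactly four `u32`s with the optimal 5-comparator network, the same network that
/// `sort4_optimal` uses in the patch above.
fn sort4_net(v: &mut [u32; 4]) {
    let p = v.as_mut_ptr();
    let cswap = |a: usize, b: usize| {
        // SAFETY: the indices used below are in bounds for a 4-element array and each pair
        // is distinct, so the two pointers never alias.
        unsafe {
            let (a_ptr, b_ptr) = (p.add(a), p.add(b));
            branchless_swap_u32(a_ptr, b_ptr, *b_ptr < *a_ptr);
        }
    };

    cswap(0, 2);
    cswap(1, 3);
    cswap(0, 1);
    cswap(2, 3);
    cswap(1, 2);
}

fn main() {
    let mut v = [3u32, 1, 4, 1];
    sort4_net(&mut v);
    assert_eq!(v, [1, 1, 3, 4]);
}

The conditional-move selection keeps the comparison results out of the control flow, which is where the patch expects its wins for small, cheap-to-move element types (the heuristic `qualifies_for_branchless_sort` gates exactly that case).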