Skip to content

Commit

Permalink
Adding _by, by_key, largest variants of k_smallest
Browse files Browse the repository at this point in the history
  • Loading branch information
ejmount authored and Philippe-Cholet committed Feb 26, 2024
1 parent 04e13c1 commit 16ce601
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 31 deletions.
104 changes: 89 additions & 15 deletions src/k_smallest.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,96 @@
use alloc::collections::BinaryHeap;
use core::cmp::Ord;
use alloc::vec::Vec;
use core::cmp::Ordering;

/// Consumes a given iterator, returning the minimum elements in **ascending** order.
pub(crate) fn k_smallest_general<I, F>(mut iter: I, k: usize, mut comparator: F) -> Vec<I::Item>
where
I: Iterator,
F: FnMut(&I::Item, &I::Item) -> Ordering,
{
/// Move the element at index `origin` toward the leaves until no child compares greater than it.
///
/// `is_less_than(a, b)` reports whether `a` orders strictly before `b`.
/// A parent that is less than its greater child is swapped with that child,
/// so **larger** elements end up closer to the root of the heap.
/// Nothing happens when `origin` is out of bounds or has no children.
fn sift_down<T, F>(heap: &mut [T], is_less_than: &mut F, mut origin: usize)
where
    F: FnMut(&T, &T) -> bool,
{
    let len = heap.len();

    while origin < len {
        // Children of node `n` live at `2n + 1` and `2n + 2` in the flat layout.
        let (left, right) = (2 * origin + 1, 2 * origin + 2);
        if left >= len {
            // `origin` is a leaf: nothing left to compare against.
            break;
        }

        // Prefer the right child only when it exists and is strictly greater than the left one.
        let mut candidate = left;
        if right < len && is_less_than(&heap[left], &heap[right]) {
            candidate = right;
        }

        // Stop as soon as the parent is not smaller than its greater child.
        if !is_less_than(&heap[origin], &heap[candidate]) {
            break;
        }

        heap.swap(origin, candidate);
        origin = candidate;
    }
}

pub(crate) fn k_smallest<T: Ord, I: Iterator<Item = T>>(mut iter: I, k: usize) -> BinaryHeap<T> {
if k == 0 {
return BinaryHeap::new();
return Vec::new();
}
let mut storage: Vec<I::Item> = iter.by_ref().take(k).collect();

let mut heap = iter.by_ref().take(k).collect::<BinaryHeap<_>>();
let mut is_less_than = move |a: &_, b: &_| comparator(a, b) == Ordering::Less;

iter.for_each(|i| {
debug_assert_eq!(heap.len(), k);
// Equivalent to heap.push(min(i, heap.pop())) but more efficient.
// This should be done with a single `.peek_mut().unwrap()` but
// `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
if *heap.peek().unwrap() > i {
*heap.peek_mut().unwrap() = i;
}
});
// Rearrange the storage into a valid heap by reordering from the second-bottom-most layer up to the root.
// Slightly faster than ordering on each insert, but only by a factor of lg(k).
// The resulting heap has the **largest** item on top.
for i in (0..=(storage.len() / 2)).rev() {
sift_down(&mut storage, &mut is_less_than, i);
}

if k == storage.len() {
// If we fill the storage, there may still be iterator elements left so feed them into the heap.
// Also avoids unexpected behaviour with restartable iterators.
iter.for_each(|val| {
if is_less_than(&val, &storage[0]) {
// Treating this as a push-and-pop saves having to write a sift-up implementation.
// https://en.wikipedia.org/wiki/Binary_heap#Insert_then_extract
storage[0] = val;
// We retain the smallest items we've seen so far, but ordered largest first so we can drop the largest efficiently.
sift_down(&mut storage, &mut is_less_than, 0);
}
});
}

// Ultimately the items need to be in least-first, strict order, but the heap is currently largest-first.
// To achieve this, repeatedly,
// 1) "pop" the largest item off the heap into the tail slot of the underlying storage,
// 2) shrink the logical size of the heap by 1,
// 3) restore the heap property over the remaining items.
let mut heap = &mut storage[..];
while heap.len() > 1 {
let last_idx = heap.len() - 1;
heap.swap(0, last_idx);
// Sifting over a truncated slice means that the sifting will not disturb already popped elements.
heap = &mut heap[..last_idx];
sift_down(heap, &mut is_less_than, 0);
}

storage
}

heap
/// Adapt a key-extraction function into a two-argument comparator.
///
/// The returned closure applies `key` to the left operand, then to the right operand,
/// and compares the two resulting keys with `Ord::cmp`.
#[inline]
pub(crate) fn key_to_cmp<T, K, F>(key: F) -> impl Fn(&T, &T) -> Ordering
where
    F: Fn(&T) -> K,
    K: Ord,
{
    move |lhs, rhs| {
        let lhs_key = key(lhs);
        let rhs_key = key(rhs);
        lhs_key.cmp(&rhs_key)
    }
}
102 changes: 98 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2950,14 +2950,108 @@ pub trait Itertools: Iterator {
/// itertools::assert_equal(five_smallest, 0..5);
/// ```
#[cfg(feature = "use_alloc")]
fn k_smallest(self, k: usize) -> VecIntoIter<Self::Item>
fn k_smallest(mut self, k: usize) -> VecIntoIter<Self::Item>
where
Self: Sized,
Self::Item: Ord,
{
crate::k_smallest::k_smallest(self, k)
.into_sorted_vec()
.into_iter()
// The stdlib heap has optimised handling of "holes", which is not included in our heap implementation in k_smallest_general.
// While the difference is unlikely to have practical impact unless `Self::Item` is very large, this method uses the stdlib structure
// to maintain performance compared to previous versions of the crate.
use alloc::collections::BinaryHeap;

// With `k == 0` the heap built below would be empty, and `heap.peek().unwrap()` would
// panic on the first remaining item, so return an empty result up front.
if k == 0 {
return Vec::new().into_iter();
}

// Seed the heap with the first `k` items (fewer if the iterator runs out earlier).
let mut heap = self.by_ref().take(k).collect::<BinaryHeap<_>>();

// Stream the remaining items, replacing the heap's top whenever a smaller item arrives.
// NOTE(review): unlike `k_smallest_general` there is no `k == len` guard here; if the
// source yields further items after having returned `None` once, the `debug_assert_eq!`
// below would presumably fire when fewer than `k` items were collected — confirm.
self.for_each(|i| {
debug_assert_eq!(heap.len(), k);
// Equivalent to heap.push(min(i, heap.pop())) but more efficient.
// This should be done with a single `.peek_mut().unwrap()` but
// `PeekMut` sifts-down unconditionally on Rust 1.46.0 and prior.
if *heap.peek().unwrap() > i {
*heap.peek_mut().unwrap() = i;
}
});

// Hand the retained items back as a sorted vector's iterator.
heap.into_sorted_vec().into_iter()
}

/// Sort the `k` smallest elements, as ordered by the supplied comparison, into a new iterator.
///
/// The relationship to `self.sorted_by(cmp).take(k)` mirrors the one between
/// [Itertools::k_smallest] and `self.sorted().take(k)`, for both semantics and complexity.
/// Notably, the custom heap implementation behind this method ensures `cmp` is not cloned.
#[cfg(feature = "use_alloc")]
fn k_smallest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
where
    Self: Sized,
    F: Fn(&Self::Item, &Self::Item) -> Ordering,
{
    let smallest = k_smallest::k_smallest_general(self, k, cmp);
    smallest.into_iter()
}

/// Return the `k` elements for which `key` produces the smallest outputs.
///
/// The relationship to `self.sorted_by_key(key).take(k)` mirrors the one between
/// [Itertools::k_smallest] and `self.sorted().take(k)`, for both semantics and time complexity.
#[cfg(feature = "use_alloc")]
fn k_smallest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
where
    Self: Sized,
    F: Fn(&Self::Item) -> K,
    K: Ord,
{
    // Turn the key extractor into a comparator and reuse the comparator-based variant.
    let by_key = k_smallest::key_to_cmp(key);
    self.k_smallest_by(k, by_key)
}

/// Sort the `k` largest elements into a new iterator, largest first.
///
/// The result is semantically that of `k_smallest` with a reversed `Ord`.
/// Note, however, that this goes through a custom binary heap, which does not
/// have the same performance characteristics for very large `Self::Item`.
///
/// ```
/// use itertools::Itertools;
///
/// // A random permutation of 0..15
/// let numbers = vec![6, 9, 1, 14, 0, 4, 8, 7, 11, 2, 10, 3, 13, 12, 5];
///
/// let five_largest = numbers
/// .into_iter()
/// .k_largest(5);
///
/// itertools::assert_equal(five_largest, vec![14,13,12,11,10]);
/// ```
#[cfg(feature = "use_alloc")]
fn k_largest(self, k: usize) -> VecIntoIter<Self::Item>
where
    Self: Sized,
    Self::Item: Ord,
{
    // Delegate to the comparator-based variant using the items' natural ordering.
    let natural_order = <Self::Item as Ord>::cmp;
    self.k_largest_by(k, natural_order)
}

/// Sort the `k` largest elements, as ordered by the supplied comparison, into a new iterator.
///
/// Functionally this is `k_smallest_by` with a reversed `Ord`: the comparison is
/// invoked with its two operands swapped.
#[cfg(feature = "use_alloc")]
fn k_largest_by<F>(self, k: usize, cmp: F) -> VecIntoIter<Self::Item>
where
    Self: Sized,
    F: Fn(&Self::Item, &Self::Item) -> Ordering,
{
    // Swapping the operands turns "smallest under `cmp`" into "largest under `cmp`".
    let reversed = move |lhs: &Self::Item, rhs: &Self::Item| cmp(rhs, lhs);
    self.k_smallest_by(k, reversed)
}

/// Return the `k` elements for which `key` produces the largest outputs.
#[cfg(feature = "use_alloc")]
fn k_largest_by_key<F, K>(self, k: usize, key: F) -> VecIntoIter<Self::Item>
where
    Self: Sized,
    F: Fn(&Self::Item) -> K,
    K: Ord,
{
    // Turn the key extractor into a comparator and reuse the comparator-based variant.
    let by_key = k_smallest::key_to_cmp(key);
    self.k_largest_by(k, by_key)
}

/// Collect all iterator elements into one of two
Expand Down
51 changes: 39 additions & 12 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,23 +492,50 @@ fn sorted_by() {
}

// Property test: on a shuffled integer range, every `k_smallest*` / `k_largest*` variant
// is compared element by element against the matching end of a sorted copy.
qc::quickcheck! {
fn k_smallest_range(n: u64, m: u16, k: u16) -> () {
fn k_smallest_range(n: i64, m: u16, k: u16) -> () {
// u16 is used to constrain k and m to 0..2¹⁶,
// otherwise the test could use too much memory.
let (k, m) = (k as u64, m as u64);
let (k, m) = (k as usize, m as u64);

let mut v: Vec<_> = (n..n.saturating_add(m as _)).collect();
// Generate a random permutation of n..n+m
let i = {
let mut v: Vec<u64> = (n..n.saturating_add(m)).collect();
v.shuffle(&mut thread_rng());
v.into_iter()
};
v.shuffle(&mut thread_rng());

// Construct the right answers for the top and bottom elements
let mut sorted = v.clone();
sorted.sort();
// how many elements are we checking
// NOTE(review): `v` is built with `saturating_add`, so it may hold fewer than `m` items
// when `n` is close to the integer maximum; the `sorted[..num_elements]` slices below
// would then index out of bounds — presumably `min(k, v.len())` is intended; confirm.
let num_elements = min(k, m as _);

// Compute the top and bottom k in various combinations
let smallest = v.iter().cloned().k_smallest(k);
let smallest_by = v.iter().cloned().k_smallest_by(k, Ord::cmp);
let smallest_by_key = v.iter().cloned().k_smallest_by_key(k, |&x| x);

let largest = v.iter().cloned().k_largest(k);
let largest_by = v.iter().cloned().k_largest_by(k, Ord::cmp);
let largest_by_key = v.iter().cloned().k_largest_by_key(k, |&x| x);

// Check the variations produce the same answers and that they're right
// NOTE(review): presumably `izip!` stops at its shortest input, in which case a variant
// returning too few items would go unnoticed here — consider asserting lengths as well.
for (a,b,c,d) in izip!(
sorted[..num_elements].iter().cloned(),
smallest,
smallest_by,
smallest_by_key) {
assert_eq!(a,b);
assert_eq!(a,c);
assert_eq!(a,d);
}

// Check that taking the k smallest elements yields n..n+min(k, m)
it::assert_equal(
i.k_smallest(k as usize),
n..n.saturating_add(min(k, m))
);
// Same comparison for the largest variants, walking the sorted tail in descending order.
for (a,b,c,d) in izip!(
sorted[sorted.len()-num_elements..].iter().rev().cloned(),
largest,
largest_by,
largest_by_key) {
assert_eq!(a,b);
assert_eq!(a,c);
assert_eq!(a,d);
}
}
}

Expand Down

0 comments on commit 16ce601

Please sign in to comment.