sorting #195
Here's a rough, scratch approach. Permuting the elements with `unsafe` code is sound here because the critical section makes no calls into user-defined code (no `Clone`, no comparator), so we can review it and ensure that it never panics.

```rust
// Written against an older ndarray release: in later versions `subview`/`subview_mut`
// are named `index_axis`/`index_axis_mut`, and `Zip::apply` is `Zip::for_each`.
extern crate ndarray;

use ndarray::prelude::*;
use ndarray::{Data, RemoveAxis, Zip};

use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;

// Type invariant: each index in 0..len appears exactly once
#[derive(Clone, Debug)]
pub struct Permutation {
    indices: Vec<usize>,
}

impl Permutation {
    /// Creates a permutation, returning `Err` if `v` is not a permutation of `0..v.len()`.
    pub fn from_indices(v: Vec<usize>) -> Result<Self, ()> {
        let perm = Permutation { indices: v };
        if perm.correct() {
            Ok(perm)
        } else {
            Err(())
        }
    }

    fn correct(&self) -> bool {
        let axis_len = self.indices.len();
        let mut seen = vec![false; axis_len];
        for &i in &self.indices {
            // Reject out-of-range indices instead of panicking on `seen[i]`
            if i >= axis_len || seen[i] {
                return false;
            }
            seen[i] = true;
        }
        true
    }
}

pub trait SortArray {
    /// ***Panics*** if `axis` is out of bounds.
    fn identity(&self, axis: Axis) -> Permutation;
    fn sort_axis_by<F>(&self, axis: Axis, less_than: F) -> Permutation
    where
        F: FnMut(usize, usize) -> bool;
}

pub trait PermuteArray {
    type Elem;
    type Dim;
    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<Self::Elem, Self::Dim>
    where
        Self::Elem: Clone,
        Self::Dim: RemoveAxis;
}

impl<A, S, D> SortArray for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn identity(&self, axis: Axis) -> Permutation {
        Permutation {
            indices: (0..self.len_of(axis)).collect(),
        }
    }

    fn sort_axis_by<F>(&self, axis: Axis, mut less_than: F) -> Permutation
    where
        F: FnMut(usize, usize) -> bool,
    {
        // After sorting, indices[k] is the source index of the k-th element in order
        let mut perm = self.identity(axis);
        perm.indices.sort_by(move |&a, &b| {
            if less_than(a, b) {
                Ordering::Less
            } else if less_than(b, a) {
                Ordering::Greater
            } else {
                Ordering::Equal
            }
        });
        perm
    }
}

impl<A, D> PermuteArray for Array<A, D>
where
    D: Dimension,
{
    type Elem = A;
    type Dim = D;

    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<A, D>
    where
        D: RemoveAxis,
    {
        let axis_len = self.len_of(axis);
        assert_eq!(axis_len, perm.indices.len());
        debug_assert!(perm.correct());

        let mut v = Vec::with_capacity(self.len());
        let mut result;

        // panic-critical begin: we must not panic
        unsafe {
            // Storage is uninitialized here; every slot is written exactly once below
            v.set_len(self.len());
            result = Array::from_shape_vec_unchecked(self.dim(), v);
            // Gather: lane i of the result is lane perm.indices[i] of the source,
            // so the result is in sorted order when `perm` comes from `sort_axis_by`
            for i in 0..axis_len {
                let perm_i = perm.indices[i];
                Zip::from(result.subview_mut(axis, i))
                    .and(self.subview(axis, perm_i))
                    .apply(|to, from| copy_nonoverlapping(from, to, 1));
            }
            // Forget the moved array elements, but still free the old Vec's allocation
            let mut old_storage = self.into_raw_vec();
            old_storage.set_len(0);
            // old_storage drops empty
        }
        // panic-critical end
        result
    }
}

fn main() {
    let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
    let strings = a.map(|x| x.to_string());
    // Order the rows of `a` (axis 0) by their value in column 0, descending
    let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] > a[[j, 0]]);
    println!("{:?}", perm);
    let b = a.permute_axis(Axis(0), &perm);
    println!("{:?}", b);
    println!("{:?}", strings);
    // Any permutation of matching length can also be applied along another axis
    let c = strings.permute_axis(Axis(1), &perm);
    println!("{:?}", c);
}
```
I hope this gets implemented soon. An argsort method would also be very useful.
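As a rough illustration of what an argsort could return, here is a sketch over a plain slice, independent of ndarray. The name `argsort_by` and its signature are hypothetical, not an existing ndarray API. It returns the indices that would sort the input, which is the same information the `Permutation` from `sort_axis_by` carries.

```rust
use std::cmp::Ordering;

/// Hypothetical helper (not an ndarray API): returns the indices that would
/// sort `v` according to `cmp`. The sort is stable, so ties keep input order.
fn argsort_by<T, F>(v: &[T], mut cmp: F) -> Vec<usize>
where
    F: FnMut(&T, &T) -> Ordering,
{
    let mut idx: Vec<usize> = (0..v.len()).collect();
    idx.sort_by(|&a, &b| cmp(&v[a], &v[b]));
    idx
}

fn main() {
    let data = vec![3.0_f64, 1.0, 2.0];
    // f64::total_cmp gives a total order, so NaN values cannot break the sort
    let order = argsort_by(&data, f64::total_cmp);
    // Gathering data[order[k]] for k = 0, 1, ... yields the sorted values
    let sorted: Vec<f64> = order.iter().map(|&i| data[i]).collect();
    println!("{:?} {:?}", order, sorted);
}
```

Taking a comparator rather than requiring `T: Ord` keeps floating-point data usable without wrapper types.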
I tried this example and found that it did not correctly sort my medium-sized array, although it worked fine on smaller arrays. I opened a draft PR with some tests that illustrate the problem. It is not meant to be merged; it only shares the test cases in case they are of interest.
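One plausible cause (an assumption, not confirmed in this thread) is the direction in which the permutation is applied. If `permute_axis` copies source lane `i` into result lane `perm.indices[i]` (a scatter), it applies the inverse of the permutation returned by `sort_axis_by`; producing the sorted order needs a gather, where result lane `i` is source lane `perm.indices[i]`. A reversal, such as the descending sort in the demo, is its own inverse, so both directions agree there, and small inputs are far more likely to yield self-inverse permutations (every permutation of two elements is one). The plain-`Vec` sketch below uses hypothetical helpers `gather` and `scatter` to show the difference.

```rust
/// Hypothetical helper: result[i] = src[perm[i]] (sorts when perm is an argsort).
fn gather<T: Clone>(src: &[T], perm: &[usize]) -> Vec<T> {
    perm.iter().map(|&p| src[p].clone()).collect()
}

/// Hypothetical helper: result[perm[i]] = src[i] (applies the inverse permutation).
fn scatter<T: Clone + Default>(src: &[T], perm: &[usize]) -> Vec<T> {
    let mut out = vec![T::default(); src.len()];
    for (i, &p) in perm.iter().enumerate() {
        out[p] = src[i].clone();
    }
    out
}

fn main() {
    let data = vec![30, 10, 20];
    // Argsort of data: data[1] = 10, data[2] = 20, data[0] = 30
    let perm = vec![1, 2, 0];
    println!("gather:  {:?}", gather(&data, &perm)); // sorted
    println!("scatter: {:?}", scatter(&data, &perm)); // not sorted

    // A reversal is its own inverse, so both directions agree
    let rev = vec![2, 1, 0];
    assert_eq!(gather(&data, &rev), scatter(&data, &rev));
}
```

A useful regression test is therefore any non-self-inverse permutation, such as the 3-cycle above, rather than a reversal.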
For a project I'm working on, I wrote a function that sorts a 2D matrix by the values in a particular column. Would you be interested in generalizing this for ndarray (e.g. sorting by row values, n-dimensional sorting, etc.)?
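The commenter's function is not shown. As a hypothetical sketch of the same idea, the following sorts the rows of a row-major matrix, stored as a flat `Vec<f64>` with `ncols` columns, by the values in one column: it computes an argsort of that column and then gathers whole rows. `sort_rows_by_column` is an illustrative name only. With the `SortArray`/`PermuteArray` traits from the first comment, the analogous steps would be `sort_axis_by(Axis(0), ...)` followed by `permute_axis(Axis(0), ...)`.

```rust
/// Hypothetical helper: sorts the rows of a row-major `nrows x ncols` matrix
/// by ascending value in column `col`, returning a new flat buffer.
fn sort_rows_by_column(data: &[f64], ncols: usize, col: usize) -> Vec<f64> {
    assert!(ncols > 0 && col < ncols && data.len() % ncols == 0);
    let nrows = data.len() / ncols;
    // Argsort of the chosen column: row indices in sorted order.
    // total_cmp gives a total order, so NaN values cannot break the sort.
    let mut order: Vec<usize> = (0..nrows).collect();
    order.sort_by(|&a, &b| data[a * ncols + col].total_cmp(&data[b * ncols + col]));
    // Gather whole rows in that order
    let mut out = Vec::with_capacity(data.len());
    for &r in &order {
        out.extend_from_slice(&data[r * ncols..(r + 1) * ncols]);
    }
    out
}

fn main() {
    // 3 x 2 matrix with rows [3, 30], [1, 10], [2, 20]
    let m = vec![3.0, 30.0, 1.0, 10.0, 2.0, 20.0];
    let sorted = sort_rows_by_column(&m, 2, 0);
    println!("{:?}", sorted);
}
```

Separating the argsort from the gather step is what makes generalizing easy: the same row order can be applied to other arrays, or the roles of rows and columns swapped.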