sorting #195

Open
kernelmachine opened this issue May 6, 2016 · 3 comments

kernelmachine commented May 6, 2016

For a project I'm working on, I wrote a function that sorts a 2-D matrix by the values in a particular column. I wanted to see whether you'd be interested in generalizing this for ndarray (e.g. sorting by row values, n-dimensional sorting, etc.).

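pub fn sort_matrix(mat: Mat<f64>, sort_col: ArrayView<f64, Ix>) -> Mat<f64> {
    // Pair each value in the sort column with its row index.
    let mut enum_col = sort_col.iter().enumerate().collect::<Vec<(usize, &f64)>>();
    // Order the pairs by value; `unwrap` panics if a NaN is compared.
    enum_col.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap());
    // Keep the row indices in sorted order and gather the rows.
    let indices = enum_col.iter().map(|x| x.0).collect::<Vec<usize>>();
    mat.select(Axis(0), indices.as_slice())
}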
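
The same two steps (argsort the key column, then gather the rows in that order) can be sketched without ndarray. This is only an illustration under stated assumptions: the helper names `argsort_f64` and `sort_rows_by_column` are hypothetical, and rows are assumed to be stored as `Vec<Vec<f64>>`. It uses `f64::total_cmp` rather than `partial_cmp(..).unwrap()`, so a NaN in the key column is ordered deterministically instead of causing a panic.

```rust
/// Returns the indices that would sort `keys` in ascending order.
/// `f64::total_cmp` is a total order, so NaN values cannot cause a panic.
fn argsort_f64(keys: &[f64]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..keys.len()).collect();
    // `sort_by` is stable, so rows with equal keys keep their relative order.
    indices.sort_by(|&a, &b| keys[a].total_cmp(&keys[b]));
    indices
}

/// Returns a copy of `rows` reordered by the values in column `col`.
/// Panics if any row has no element at index `col`.
fn sort_rows_by_column(rows: &[Vec<f64>], col: usize) -> Vec<Vec<f64>> {
    // Extract the key column, argsort it, then gather rows in that order.
    let keys: Vec<f64> = rows.iter().map(|r| r[col]).collect();
    argsort_f64(&keys)
        .into_iter()
        .map(|i| rows[i].clone())
        .collect()
}
```

For the rows `[3.0, 10.0]`, `[1.0, 30.0]`, `[2.0, 20.0]`, calling `sort_rows_by_column(&rows, 0)` returns them in the order `[1.0, 30.0]`, `[2.0, 20.0]`, `[3.0, 10.0]`.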

bluss commented Apr 2, 2017

Here's a scratch approach. Permuting elements is safe because there are no user-defined calls in the critical section, so we can review it and ensure it never panics.

extern crate ndarray;

use ndarray::prelude::*;
use ndarray::{
    Data,
    RemoveAxis,
    Zip,
};

use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;

// Type invariant: Each index appears exactly once
#[derive(Clone, Debug)]
pub struct Permutation {
    indices: Vec<usize>,
}

impl Permutation {
    /// Checks if the permutation is correct
    pub fn from_indices(v: Vec<usize>) -> Result<Self, ()> {
        let perm = Permutation { indices: v };
        if perm.correct() {
            Ok(perm)
        } else {
            Err(())
        }
    }

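    fn correct(&self) -> bool {
        let axis_len = self.indices.len();
        let mut seen = vec![false; axis_len];
        for &i in &self.indices {
            // `get_mut` rejects out-of-range indices instead of panicking.
            match seen.get_mut(i) {
                None => return false,
                Some(s) => {
                    if *s {
                        return false;
                    }
                    *s = true;
                }
            }
        }
        true
    }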
}

pub trait SortArray {
    /// ***Panics*** if `axis` is out of bounds.
    fn identity(&self, axis: Axis) -> Permutation;
    fn sort_axis_by<F>(&self, axis: Axis, less_than: F) -> Permutation
        where F: FnMut(usize, usize) -> bool;
}

pub trait PermuteArray {
    type Elem;
    type Dim;
    fn permute_axis(self, axis: Axis, perm: &Permutation)
        -> Array<Self::Elem, Self::Dim>
        where Self::Elem: Clone, Self::Dim: RemoveAxis;
}

impl<A, S, D> SortArray for ArrayBase<S, D>
    where S: Data<Elem=A>,
          D: Dimension,
{
    fn identity(&self, axis: Axis) -> Permutation {
        Permutation {
            indices: (0..self.len_of(axis)).collect(),
        }
    }

    fn sort_axis_by<F>(&self, axis: Axis, mut less_than: F) -> Permutation
        where F: FnMut(usize, usize) -> bool
    {
        let mut perm = self.identity(axis);
        perm.indices.sort_by(move |&a, &b|
            if less_than(a, b) {
                Ordering::Less
            } else if less_than(b, a) {
                Ordering::Greater
            } else {
                Ordering::Equal
            });
        perm
    }
}

impl<A, D> PermuteArray for Array<A, D>
    where D: Dimension,
{
    type Elem = A;
    type Dim = D;

    fn permute_axis(self, axis: Axis, perm: &Permutation) -> Array<A, D>
        where D: RemoveAxis,
    {
        let axis = axis;
        let axis_len = self.len_of(axis);
        assert_eq!(axis_len, perm.indices.len());
        debug_assert!(perm.correct());

        let mut v = Vec::with_capacity(self.len());
        let mut result;

        // panic-critical begin: we must not panic
        unsafe {
            v.set_len(self.len());
            result = Array::from_shape_vec_unchecked(self.dim(), v);
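            for i in 0..axis_len {
                let perm_i = perm.indices[i];
                // Gather: position `i` of the result receives source lane
                // `perm_i`, which is what the sorted index list describes.
                Zip::from(result.subview_mut(axis, i))
                    .and(self.subview(axis, perm_i))
                    .apply(|to, from| {
                        copy_nonoverlapping(from, to, 1)
                    });
            }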
            // forget moved array elements but not its vec
            let mut old_storage = self.into_raw_vec();
            old_storage.set_len(0);
            // old_storage drops empty
        }
        // panic-critical end
        result
    }
}


fn main() {
    let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
    let strings = a.map(|x| x.to_string());

    // Order the rows (axis 0) by descending value in column 0.
    let perm = a.sort_axis_by(Axis(0), |i, j| {
        a[[i, 0]] > a[[j, 0]]
    });
    println!("{:?}", perm);
    let b = a.permute_axis(Axis(0), &perm);
    println!("{:?}", b);

    println!("{:?}", strings);
    let c = strings.permute_axis(Axis(1), &perm);
    println!("{:?}", c);
}
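
For readers who want to see what permuting along an axis means without the unsafe moves, here is a safe, clone-based sketch on a flat row-major 2-D buffer. It is not part of ndarray: `permute_axis_2d` is a hypothetical helper, and the sketch assumes the gather convention `out[k] = src[perm[k]]`, i.e. `perm[k]` names the source index that ends up at position `k` along the chosen axis.

```rust
/// Permutes a row-major `rows x cols` buffer along one axis.
/// `axis == 0` reorders rows and `axis == 1` reorders columns, using the
/// gather convention `out[k] = src[perm[k]]` along that axis.
/// Panics if the lengths are inconsistent or an index is out of range.
fn permute_axis_2d<T: Clone>(
    src: &[T],
    rows: usize,
    cols: usize,
    axis: usize,
    perm: &[usize],
) -> Vec<T> {
    assert_eq!(src.len(), rows * cols);
    assert!(axis < 2);
    assert_eq!(perm.len(), if axis == 0 { rows } else { cols });

    let mut out = Vec::with_capacity(src.len());
    for r in 0..rows {
        for c in 0..cols {
            // Look up the source coordinates for destination (r, c).
            let (sr, sc) = if axis == 0 { (perm[r], c) } else { (r, perm[c]) };
            out.push(src[sr * cols + sc].clone());
        }
    }
    out
}
```

With the 2 x 3 buffer `[1, 2, 3, 4, 5, 6]`, permuting axis 0 by `[1, 0]` swaps the two rows, and permuting axis 1 by `[2, 0, 1]` gives `[3, 1, 2, 6, 4, 5]`. Cloning costs more than moving elements, but every step is bounds-checked.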

marcbone commented Aug 9, 2019

I hope this gets implemented soon. An argsort method would also be very useful.

dam5h commented Feb 9, 2021

I tried using this example and found that it didn't correctly sort my medium-sized array, although it worked fine on smaller arrays. I opened a draft PR with some tests that illustrate the problem. It isn't meant to be merged; I'm just sharing the test cases in case they're of interest.

#916
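
One pattern that can produce this kind of symptom, offered as an assumption rather than a confirmed diagnosis, is applying the inverse of the sort permutation. A sorted index list says which source element belongs at each position (a gather), whereas writing element `i` to position `perm[i]` (a scatter) applies the inverse. The two agree whenever the permutation is its own inverse, such as a full reversal, so a reversed test input would not reveal the difference. The std-only sketch below uses hypothetical helper names.

```rust
/// Gather: `out[k] = src[perm[k]]`. This matches an argsort result,
/// where `perm[k]` names the element that belongs at position `k`.
fn apply_gather<T: Clone>(src: &[T], perm: &[usize]) -> Vec<T> {
    perm.iter().map(|&p| src[p].clone()).collect()
}

/// Scatter: `out[perm[i]] = src[i]`. This applies the inverse permutation.
/// Assumes `perm` is a valid permutation of `0..src.len()`.
fn apply_scatter<T: Clone>(src: &[T], perm: &[usize]) -> Vec<T> {
    let mut out = src.to_vec();
    for (i, &p) in perm.iter().enumerate() {
        out[p] = src[i].clone();
    }
    out
}

/// Indices that sort `keys` in ascending order.
fn argsort(keys: &[i32]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..keys.len()).collect();
    idx.sort_by_key(|&i| keys[i]);
    idx
}
```

For the reversed input `[3, 2, 1]`, the argsort is `[2, 1, 0]` and both functions return `[1, 2, 3]`. For `[2, 3, 1]`, the argsort is `[2, 0, 1]`; the gather returns `[1, 2, 3]`, but the scatter returns `[3, 1, 2]`.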
