Add tensor sorting operations (#1488)
* Add sort, sort_with_indices and argsort

* Fix wasm build

* Add sort ops autodiff

* Fix TODO parallel comment

* Fix sort_with_indices 1d and add descending options

* Fix clippy

* Fix ascending comment (configurable)
laggui authored Mar 20, 2024
1 parent 430f642 commit 47a84cc
Showing 23 changed files with 1,525 additions and 19 deletions.
6 changes: 6 additions & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -231,6 +231,12 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
| `tensor.sort(dim)` | `tensor.sort(dim).values` |
| `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` |
| `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` |
| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
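
Taken together, the new entries read roughly as follows from user code. This is a minimal sketch assuming a non-wasm build and some backend `B`; the function name and values are illustrative only:

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

/// Illustrative only: sorts each row of a 2-D float tensor in several ways.
fn sort_examples<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    // Values sorted in ascending order along dim 1 (row-wise).
    let _ascending = x.clone().sort(1);
    // Sorted values together with the permutation indices.
    let (_values, _indices) = x.clone().sort_with_indices(1);
    // Descending values, plus the indices that sort each row ascending.
    (x.clone().sort_descending(1), x.argsort(1))
}
```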

### Float Operations

24 changes: 24 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -368,4 +368,28 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
    fn int_prod_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
        B::int_prod_dim(tensor, dim)
    }

    fn int_sort<const D: usize>(
        tensor: IntTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> IntTensor<Self, D> {
        B::int_sort(tensor, dim, descending)
    }

    fn int_sort_with_indices<const D: usize>(
        tensor: IntTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> (IntTensor<Self, D>, IntTensor<Self, D>) {
        B::int_sort_with_indices(tensor, dim, descending)
    }

    fn int_argsort<const D: usize>(
        tensor: IntTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> IntTensor<Self, D> {
        B::int_argsort(tensor, dim, descending)
    }
}
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/ops/mod.rs
@@ -7,6 +7,7 @@ mod module;
mod tensor;

pub(crate) mod maxmin;
pub(crate) mod sort;

pub use backward::*;
pub use base::*;
25 changes: 25 additions & 0 deletions crates/burn-autodiff/src/ops/sort.rs
@@ -0,0 +1,25 @@
use super::{unary, Backward, Ops};
use crate::{checkpoint::base::Checkpointer, grads::Gradients};
use burn_tensor::{backend::Backend, Shape};

#[derive(Debug)]
pub(crate) struct SortDim;

impl<B: Backend, const D: usize> Backward<B, D, 1> for SortDim {
    type State = (B::IntTensorPrimitive<D>, Shape<D>);

    fn backward(
        self,
        ops: Ops<Self::State, 1>,
        grads: &mut Gradients,
        _checkpointer: &mut Checkpointer,
    ) {
        unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
            let (indices, shape) = ops.state;
            let device = B::float_device(&grad);
            let zeros = B::float_zeros(shape, &device);

            B::float_scatter(D - 1, zeros, indices, grad)
        });
    }
}
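
The backward pass treats sorting as a pure permutation: the saved indices and input shape are used to scatter the incoming gradient into a zero tensor, so each gradient value flows back to the element's original position. A minimal sketch of that behaviour through the public API (hypothetical values; assumes the NdArray backend from the burn-ndarray crate wrapped in Autodiff, mirroring the tests added below):

```rust
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

type B = Autodiff<NdArray>;

fn main() {
    let device = Default::default();
    // x = [3.0, 1.0, 2.0] sorts to [1.0, 2.0, 3.0] with indices [1, 2, 0].
    let x = Tensor::<B, 1>::from_floats([3.0, 1.0, 2.0], &device).require_grad();
    let weights = Tensor::<B, 1>::from_floats([10.0, 20.0, 30.0], &device);

    // The gradient w.r.t. the sorted values is [10, 20, 30]; the scatter in
    // SortDim::backward routes it back to the original slots: [30, 10, 20].
    let grads = x.clone().sort(0).mul(weights).backward();
    let grad_x = x.grad(&grads).unwrap();
    println!("{:?}", grad_x.to_data());
}
```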
59 changes: 59 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -21,6 +21,7 @@ use burn_tensor::{
};

use super::maxmin::MaxMinDim;
use super::sort::SortDim;

impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C> {
    fn float_from_data<const D: usize>(
@@ -2436,6 +2437,64 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
            .stateless(B::float_sign(tensor.primitive))
    }

    fn float_sort<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> FloatTensor<Self, D> {
        match SortDim
            .prepare::<C>([tensor.node], [tensor.graph])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(prep) => {
                let shape = B::float_shape(&tensor.primitive);
                let (tensor, indices) =
                    B::float_sort_with_indices(tensor.primitive, dim, descending);
                prep.finish((indices, shape), tensor)
            }
            OpsKind::UnTracked(prep) => {
                prep.finish(B::float_sort(tensor.primitive, dim, descending))
            }
        }
    }

    fn float_sort_with_indices<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> (FloatTensor<Self, D>, IntTensor<B, D>) {
        match SortDim
            .prepare::<C>([tensor.node], [tensor.graph])
            .compute_bound()
            .stateful()
        {
            OpsKind::Tracked(prep) => {
                let shape = B::float_shape(&tensor.primitive);
                let (tensor, indices) =
                    B::float_sort_with_indices(tensor.primitive, dim, descending);
                let tensor = prep.finish((indices.clone(), shape), tensor);

                (tensor, indices)
            }
            OpsKind::UnTracked(prep) => {
                let (tensor, indices) =
                    B::float_sort_with_indices(tensor.primitive, dim, descending);
                let tensor = prep.finish(tensor);

                (tensor, indices)
            }
        }
    }

    fn float_argsort<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim: usize,
        descending: bool,
    ) -> IntTensor<B, D> {
        B::float_argsort(tensor.primitive, dim, descending)
    }

    // TODO: Implement float_prod and float_sum
    // https://github.com/tracel-ai/burn/issues/1458
}
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -48,6 +48,7 @@ mod sign;
mod sin;
mod slice;
mod softmax;
mod sort;
mod sqrt;
mod sub;
mod tanh;
@@ -118,5 +119,6 @@ macro_rules! testgen_all {
        burn_autodiff::testgen_ad_flip!();
        burn_autodiff::testgen_ad_nonzero!();
        burn_autodiff::testgen_ad_sign!();
        burn_autodiff::testgen_ad_sort!();
    };
}
51 changes: 51 additions & 0 deletions crates/burn-autodiff/src/tests/sort.rs
@@ -0,0 +1,51 @@
#[burn_tensor_testgen::testgen(ad_sort)]
mod tests {
    use super::*;
    use burn_tensor::Data;

    #[test]
    fn should_diff_sort() {
        let device = Default::default();
        let tensor_1 =
            TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
        let tensor_2 =
            TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[35.0, 35.0], [-1.0, -8.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[11.0, 7.0], [55.0, 16.0]]), 5);
    }

    #[test]
    fn should_diff_sort_with_indices() {
        let device = Default::default();
        let tensor_1 =
            TestAutodiffTensor::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();
        let tensor_2 =
            TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();

        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let tensor_4 = tensor_1.clone().mul(tensor_3.sort_with_indices(1).0);
        let grads = tensor_4.backward();

        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();

        grad_1
            .to_data()
            .assert_approx_eq(&Data::from([[35.0, 35.0], [-1.0, -8.0]]), 5);
        grad_2
            .to_data()
            .assert_approx_eq(&Data::from([[11.0, 7.0], [55.0, 16.0]]), 5);
    }
}
16 changes: 16 additions & 0 deletions crates/burn-tch/src/ops/base.rs
@@ -494,4 +494,20 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
    pub fn sign<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())
    }

    pub fn sort<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
        descending: bool,
    ) -> TchTensor<E, D> {
        TchTensor::new(tensor.tensor.sort(dim as i64, descending).0)
    }

    pub fn argsort<const D: usize>(
        tensor: TchTensor<E, D>,
        dim: usize,
        descending: bool,
    ) -> TchTensor<i64, D> {
        TchTensor::new(tensor.tensor.argsort(dim as i64, descending))
    }
}
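
Both helpers delegate to LibTorch's own kernels: `tch::Tensor::sort` returns a `(values, indices)` pair, so the values-only path keeps `.0`, while `argsort` yields the index tensor directly. A rough end-user sketch against the LibTorch backend follows (hypothetical values and device; assumes the burn-tch crate's `LibTorch` and `LibTorchDevice` exports):

```rust
use burn_tch::{LibTorch, LibTorchDevice};
use burn_tensor::Tensor;

fn main() {
    let device = LibTorchDevice::Cpu;
    let x = Tensor::<LibTorch<f32>, 2>::from_floats(
        [[3.0, 1.0, 2.0], [0.5, 4.0, -1.0]],
        &device,
    );

    // Row-wise descending values, backed by torch's sort kernel.
    let descending = x.clone().sort_descending(1);
    // Row-wise ascending permutation indices, backed by torch's argsort.
    let indices = x.argsort(1);

    println!("{descending}");
    println!("{indices}");
}
```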
16 changes: 16 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
@@ -486,4 +486,20 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
    ) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
        TchOps::sign(tensor)
    }

    fn int_sort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
        dim: usize,
        descending: bool,
    ) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
        TchOps::sort(tensor, dim, descending)
    }

    fn int_argsort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
        dim: usize,
        descending: bool,
    ) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
        TchOps::argsort(tensor, dim, descending)
    }
}
16 changes: 16 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
@@ -495,4 +495,20 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
    ) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
        TchOps::sign(tensor)
    }

    fn float_sort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
        dim: usize,
        descending: bool,
    ) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
        TchOps::sort(tensor, dim, descending)
    }

    fn float_argsort<const D: usize>(
        tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
        dim: usize,
        descending: bool,
    ) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
        TchOps::argsort(tensor, dim, descending)
    }
}
15 changes: 15 additions & 0 deletions crates/burn-tensor/src/tensor/api/check.rs
@@ -803,6 +803,21 @@ impl TensorCheck {
        check
    }

    pub(crate) fn sort_dim<const D: usize>(ops: &str, dim: usize) -> Self {
        let mut check = Self::Ok;

        if dim > D {
            check = check.register(
                ops,
                TensorError::new(format!(
                    "Can't sort a tensor with ({D}) dimensions on axis ({dim})"
                )),
            );
        }

        check
    }

    /// The goal is to minimize the cost of checks when there are no error, but it's way less
    /// important when an error occurred, crafting a comprehensive error message is more important
    /// than optimizing string manipulation.
32 changes: 32 additions & 0 deletions crates/burn-tensor/src/tensor/api/float.rs
@@ -9,6 +9,9 @@ use crate::tensor::{Data, Distribution, Shape};
use crate::Int;
use crate::Tensor;

#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::{argsort, sort, sort_with_indices, Float};

impl<const D: usize, B> Tensor<B, D>
where
    B: Backend,
@@ -261,4 +264,33 @@ where
            .matmul(centered)
            .div_scalar(n as f32 - correction_factor as f32)
    }

    /// Sort the elements by value in ascending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn sort(self, dim: usize) -> Tensor<B, D> {
        Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
    }

    /// Sort the elements by value in ascending order along a given dimension.
    /// Also returns the indices.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
        let (values, indices) =
            sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ false).await;
        (Tensor::new(values), Tensor::new(indices))
    }

    /// Returns the indices that sort the elements by value in ascending order along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
        Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
    }
}
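
On wasm targets built without the `wasm-sync` feature, these methods are async and must be awaited. A minimal sketch (the helper name is hypothetical; it only compiles under that wasm configuration):

```rust
use burn_tensor::{backend::Backend, Int, Tensor};

/// Hypothetical helper: row-wise ascending sort returning the sorted values
/// and the original position of each element. Async because the wasm build
/// of the sort family reads tensor data asynchronously.
async fn sorted_rows<B: Backend>(x: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 2, Int>) {
    x.sort_with_indices(1).await
}
```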
36 changes: 36 additions & 0 deletions crates/burn-tensor/src/tensor/api/int.rs
@@ -1,6 +1,9 @@
use crate::{backend::Backend, Data, Float, Int, Tensor};
use core::ops::Range;

#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
use crate::{argsort, check, check::TensorCheck, sort, sort_with_indices};

impl<B> Tensor<B, 1, Int>
where
    B: Backend,
@@ -66,4 +69,37 @@ where
    pub fn float(self) -> Tensor<B, D, Float> {
        Tensor::new(B::int_into_float(self.primitive))
    }

    /// Sort the elements by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn sort(self, dim: usize, descending: bool) -> Tensor<B, D, Int> {
        Tensor::new(sort::<B, D, Int>(self.primitive, dim, descending).await)
    }

    /// Sort the elements by value along a given dimension.
    /// Also returns the indices.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn sort_with_indices(
        self,
        dim: usize,
        descending: bool,
    ) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
        check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
        let (values, indices) =
            sort_with_indices::<B, D, Int>(self.primitive, dim, descending).await;
        (Tensor::new(values), Tensor::new(indices))
    }

    /// Returns the indices that sort the elements by value along a given dimension.
    ///
    /// This sort is unstable (i.e., may reorder equal elements).
    #[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
    pub async fn argsort(self, dim: usize, descending: bool) -> Tensor<B, D, Int> {
        check!(TensorCheck::sort_dim::<D>("Argsort", dim));
        Tensor::new(argsort::<B, D, Int>(self.primitive, dim, descending).await)
    }
}