Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tensor sorting operations #1488

Merged
merged 7 commits into from
Mar 20, 2024
Merged

Add tensor sorting operations #1488

merged 7 commits into from
Mar 20, 2024

Conversation

laggui
Copy link
Member

@laggui laggui commented Mar 18, 2024

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Closes #1369
Unblocks easy implementation for topk issue #1421

Changes

Added sort(dim), sort_with_indices(dim) and argsort(dim) numeric tensor ops.
Also added sort_descending(dim), sort_descending_with_indices(dim) and argsort_descending(dim) for sorting in descending order.

  • Added ElementComparison trait for Element so elements can return an ordering (as required by sort methods)
  • Implemented default tensor ops for all backends
  • Implemented for tch backend
  • Added autodiff backward

TODO

  • Parallelize default implementation loop with rayon
    • Note: we have run_par, iter_par and iter_range_par macros defined in burn-ndarray

Testing

Forward and backward unit tests.

@laggui laggui force-pushed the feat/tensor-sort-argsort branch from 944c1c3 to 12c0f67 Compare March 18, 2024 18:05
Copy link

codecov bot commented Mar 18, 2024

Codecov Report

Attention: Patch coverage is 89.27336% with 93 lines in your changes are missing coverage. Please review.

Project coverage is 85.86%. Comparing base (6e58663) to head (d65f32f).
Report is 3 commits behind head on main.

Files Patch % Lines
crates/burn-autodiff/src/ops/int_tensor.rs 0.00% 20 Missing ⚠️
crates/burn-tensor/src/tensor/element.rs 38.46% 16 Missing ⚠️
crates/burn-tch/src/ops/base.rs 0.00% 13 Missing ⚠️
crates/burn-tch/src/ops/int_tensor.rs 0.00% 13 Missing ⚠️
crates/burn-tch/src/ops/tensor.rs 0.00% 13 Missing ⚠️
crates/burn-tensor/src/tensor/api/numeric.rs 88.40% 8 Missing ⚠️
crates/burn-tensor/src/tests/ops/sort_argsort.rs 97.34% 8 Missing ⚠️
crates/burn-autodiff/src/ops/sort.rs 93.33% 1 Missing ⚠️
crates/burn-autodiff/src/ops/tensor.rs 98.11% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1488      +/-   ##
==========================================
+ Coverage   85.80%   85.86%   +0.06%     
==========================================
  Files         650      659       +9     
  Lines       72524    74082    +1558     
==========================================
+ Hits        62229    63612    +1383     
- Misses      10295    10470     +175     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@laggui laggui marked this pull request as ready for review March 19, 2024 13:41
Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a ticket to remove the async version of those methods, since it adds complexity and is only viable because we don't have a wgpu implementation yet. We could call read_sync or an equivalent in the default implementation.

Comment on lines +272 to +292
pub async fn sort(self, dim: usize) -> Tensor<B, D> {
Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
}

/// Sort the elements by value in ascending order along a given dimension.
/// Also returns the indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D>, Tensor<B, D, Int>) {
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
let (values, indices) =
sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ false).await;
(Tensor::new(values), Tensor::new(indices))
}

/// Returns the indices that sort the elements by value in ascending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we put those into numeric to avoid duplication? With newer version of Rust, we can add async in traits.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I guess I should update my rust version, it wouldn't let me do it initially.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind it was a warning not an error 🤦

use of `async fn` in public traits is discouraged as auto trait bounds cannot be specified

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will leave it like this for now to avoid the warnings. This will be temporary anyways with #1490.

crates/burn-tensor/src/tensor/api/numeric.rs Outdated Show resolved Hide resolved
Comment on lines +62 to +65
pub trait ElementComparison {
/// Returns and [Ordering] between `self` and `other`.
fn cmp(&self, other: &Self) -> Ordering;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect 🔥

@laggui laggui self-assigned this Mar 19, 2024
@laggui laggui changed the title Feat/tensor sort argsort Add tensor sorting operations Mar 19, 2024
@laggui laggui requested a review from nathanielsimard March 19, 2024 16:54
@laggui laggui merged commit 47a84cc into main Mar 20, 2024
15 checks passed
@laggui laggui deleted the feat/tensor-sort-argsort branch March 20, 2024 18:51
@laggui laggui mentioned this pull request Mar 20, 2024
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

Add argsort tensor op
2 participants